1*3e777be0SXin Li //
2*3e777be0SXin Li // Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
3*3e777be0SXin Li // SPDX-License-Identifier: MIT
4*3e777be0SXin Li //
5*3e777be0SXin Li
6*3e777be0SXin Li #define LOG_TAG "ArmnnDriver"
7*3e777be0SXin Li
8*3e777be0SXin Li #include "ArmnnPreparedModel_1_2.hpp"
9*3e777be0SXin Li
10*3e777be0SXin Li #include "Utils.hpp"
11*3e777be0SXin Li
12*3e777be0SXin Li #include <armnn/Types.hpp>
13*3e777be0SXin Li
14*3e777be0SXin Li #include <log/log.h>
15*3e777be0SXin Li #include <OperationsUtils.h>
16*3e777be0SXin Li #include <ExecutionBurstServer.h>
17*3e777be0SXin Li #include <ValidateHal.h>
18*3e777be0SXin Li
19*3e777be0SXin Li #include <chrono>
20*3e777be0SXin Li #include <cinttypes>
21*3e777be0SXin Li
22*3e777be0SXin Li #ifdef ARMNN_ANDROID_S
23*3e777be0SXin Li #include <LegacyUtils.h>
24*3e777be0SXin Li #endif
25*3e777be0SXin Li
26*3e777be0SXin Li using namespace android;
27*3e777be0SXin Li using namespace android::hardware;
28*3e777be0SXin Li
29*3e777be0SXin Li namespace {
30*3e777be0SXin Li
31*3e777be0SXin Li static const V1_2::Timing g_NoTiming = {.timeOnDevice = UINT64_MAX, .timeInDriver = UINT64_MAX};
32*3e777be0SXin Li using namespace armnn_driver;
// Monotonic clock used for all driver-side timing measurements.
using TimePoint = std::chrono::steady_clock::time_point;

// Samples the monotonic clock.
TimePoint Now()
{
    return TimePoint::clock::now();
}

// Elapsed time between two time points, truncated to whole microseconds.
unsigned long MicrosecondsDuration(TimePoint endPoint, TimePoint startPoint)
{
    const auto elapsedMicroseconds =
        std::chrono::duration_cast<std::chrono::microseconds>(endPoint - startPoint);
    return static_cast<unsigned long>(elapsedMicroseconds.count());
}
45*3e777be0SXin Li
NotifyCallbackAndCheck(const::android::sp<V1_0::IExecutionCallback> & callback,V1_0::ErrorStatus errorStatus,std::vector<V1_2::OutputShape>,const V1_2::Timing,std::string callingFunction)46*3e777be0SXin Li void NotifyCallbackAndCheck(const ::android::sp<V1_0::IExecutionCallback>& callback,
47*3e777be0SXin Li V1_0::ErrorStatus errorStatus,
48*3e777be0SXin Li std::vector<V1_2::OutputShape>,
49*3e777be0SXin Li const V1_2::Timing,
50*3e777be0SXin Li std::string callingFunction)
51*3e777be0SXin Li {
52*3e777be0SXin Li Return<void> returned = callback->notify(errorStatus);
53*3e777be0SXin Li // This check is required, if the callback fails and it isn't checked it will bring down the service
54*3e777be0SXin Li if (!returned.isOk())
55*3e777be0SXin Li {
56*3e777be0SXin Li ALOGE("ArmnnDriver::%s: hidl callback failed to return properly: %s",
57*3e777be0SXin Li callingFunction.c_str(), returned.description().c_str());
58*3e777be0SXin Li }
59*3e777be0SXin Li }
60*3e777be0SXin Li
NotifyCallbackAndCheck(const::android::sp<V1_2::IExecutionCallback> & callback,V1_0::ErrorStatus errorStatus,std::vector<V1_2::OutputShape> outputShapes,const V1_2::Timing timing,std::string callingFunction)61*3e777be0SXin Li void NotifyCallbackAndCheck(const ::android::sp<V1_2::IExecutionCallback>& callback,
62*3e777be0SXin Li V1_0::ErrorStatus errorStatus,
63*3e777be0SXin Li std::vector<V1_2::OutputShape> outputShapes,
64*3e777be0SXin Li const V1_2::Timing timing,
65*3e777be0SXin Li std::string callingFunction)
66*3e777be0SXin Li {
67*3e777be0SXin Li Return<void> returned = callback->notify_1_2(errorStatus, outputShapes, timing);
68*3e777be0SXin Li // This check is required, if the callback fails and it isn't checked it will bring down the service
69*3e777be0SXin Li if (!returned.isOk())
70*3e777be0SXin Li {
71*3e777be0SXin Li ALOGE("ArmnnDriver::%s: hidl callback failed to return properly: %s",
72*3e777be0SXin Li callingFunction.c_str(), returned.description().c_str());
73*3e777be0SXin Li }
74*3e777be0SXin Li }
75*3e777be0SXin Li
ValidateRequestArgument(const V1_0::RequestArgument & requestArg,const armnn::TensorInfo & tensorInfo)76*3e777be0SXin Li bool ValidateRequestArgument(const V1_0::RequestArgument& requestArg, const armnn::TensorInfo& tensorInfo)
77*3e777be0SXin Li {
78*3e777be0SXin Li if (requestArg.dimensions.size() != 0)
79*3e777be0SXin Li {
80*3e777be0SXin Li if (requestArg.dimensions.size() != tensorInfo.GetNumDimensions())
81*3e777be0SXin Li {
82*3e777be0SXin Li ALOGE("Mismatched dimensions (request argument: %zu, expected: %u)",
83*3e777be0SXin Li requestArg.dimensions.size(), tensorInfo.GetNumDimensions());
84*3e777be0SXin Li return false;
85*3e777be0SXin Li }
86*3e777be0SXin Li
87*3e777be0SXin Li for (unsigned int d = 0; d < tensorInfo.GetNumDimensions(); ++d)
88*3e777be0SXin Li {
89*3e777be0SXin Li if (requestArg.dimensions[d] != 0 && requestArg.dimensions[d] != tensorInfo.GetShape()[d])
90*3e777be0SXin Li {
91*3e777be0SXin Li ALOGE("Mismatched size for dimension %d (request argument: %u, expected %u)",
92*3e777be0SXin Li d, requestArg.dimensions[d], tensorInfo.GetShape()[d]);
93*3e777be0SXin Li return false;
94*3e777be0SXin Li }
95*3e777be0SXin Li }
96*3e777be0SXin Li }
97*3e777be0SXin Li
98*3e777be0SXin Li return true;
99*3e777be0SXin Li }
100*3e777be0SXin Li
GetTensorForRequestArgument(const V1_0::RequestArgument & requestArg,const armnn::TensorInfo & tensorInfo,const std::vector<::android::nn::RunTimePoolInfo> & requestPools)101*3e777be0SXin Li armnn::Tensor GetTensorForRequestArgument(const V1_0::RequestArgument& requestArg,
102*3e777be0SXin Li const armnn::TensorInfo& tensorInfo,
103*3e777be0SXin Li const std::vector<::android::nn::RunTimePoolInfo>& requestPools)
104*3e777be0SXin Li {
105*3e777be0SXin Li if (!ValidateRequestArgument(requestArg, tensorInfo))
106*3e777be0SXin Li {
107*3e777be0SXin Li return armnn::Tensor();
108*3e777be0SXin Li }
109*3e777be0SXin Li
110*3e777be0SXin Li return armnn::Tensor(tensorInfo, GetMemoryFromPool(requestArg.location, requestPools));
111*3e777be0SXin Li }
112*3e777be0SXin Li
// Builds a dump-file tensor name by appending the binding index to a prefix,
// e.g. ("Input", 2) -> "Input2".
inline std::string BuildTensorName(const char* tensorNamePrefix, std::size_t index)
{
    std::string name(tensorNamePrefix);
    name += std::to_string(index);
    return name;
}
117*3e777be0SXin Li
118*3e777be0SXin Li } // anonymous namespace
119*3e777be0SXin Li
120*3e777be0SXin Li using namespace android::hardware;
121*3e777be0SXin Li
122*3e777be0SXin Li namespace armnn_driver
123*3e777be0SXin Li {
124*3e777be0SXin Li
// Static worker thread shared by all prepared models of this HAL version.
// NOTE(review): presumably services queued asynchronous execute() requests —
// its use is in RequestThread / other parts of this file; confirm there.
template<typename HalVersion>
RequestThread<ArmnnPreparedModel_1_2, HalVersion, CallbackContext_1_2>
    ArmnnPreparedModel_1_2<HalVersion>::m_RequestThread;

// Thread pool used when asynchronous model execution is enabled. It is
// created lazily by the first model that needs it and then shared by all
// prepared models (see the constructors), hence a static member.
template<typename HalVersion>
std::unique_ptr<armnn::Threadpool> ArmnnPreparedModel_1_2<HalVersion>::m_Threadpool(nullptr);
131*3e777be0SXin Li
132*3e777be0SXin Li template<typename HalVersion>
133*3e777be0SXin Li template<typename TensorBindingCollection>
DumpTensorsIfRequired(char const * tensorNamePrefix,const TensorBindingCollection & tensorBindings)134*3e777be0SXin Li void ArmnnPreparedModel_1_2<HalVersion>::DumpTensorsIfRequired(char const* tensorNamePrefix,
135*3e777be0SXin Li const TensorBindingCollection& tensorBindings)
136*3e777be0SXin Li {
137*3e777be0SXin Li if (!m_RequestInputsAndOutputsDumpDir.empty())
138*3e777be0SXin Li {
139*3e777be0SXin Li const std::string requestName = std::to_string(m_NetworkId) + "_" + std::to_string(m_RequestCount) + ".dump";
140*3e777be0SXin Li for (std::size_t i = 0u; i < tensorBindings.size(); ++i)
141*3e777be0SXin Li {
142*3e777be0SXin Li DumpTensor(m_RequestInputsAndOutputsDumpDir,
143*3e777be0SXin Li requestName,
144*3e777be0SXin Li BuildTensorName(tensorNamePrefix, i),
145*3e777be0SXin Li tensorBindings[i].second);
146*3e777be0SXin Li }
147*3e777be0SXin Li }
148*3e777be0SXin Li }
149*3e777be0SXin Li
150*3e777be0SXin Li template<typename HalVersion>
ArmnnPreparedModel_1_2(armnn::NetworkId networkId,armnn::IRuntime * runtime,const V1_2::Model & model,const std::string & requestInputsAndOutputsDumpDir,const bool gpuProfilingEnabled,const bool asyncModelExecutionEnabled,const unsigned int numberOfThreads,const bool importEnabled,const bool exportEnabled)151*3e777be0SXin Li ArmnnPreparedModel_1_2<HalVersion>::ArmnnPreparedModel_1_2(armnn::NetworkId networkId,
152*3e777be0SXin Li armnn::IRuntime* runtime,
153*3e777be0SXin Li const V1_2::Model& model,
154*3e777be0SXin Li const std::string& requestInputsAndOutputsDumpDir,
155*3e777be0SXin Li const bool gpuProfilingEnabled,
156*3e777be0SXin Li const bool asyncModelExecutionEnabled,
157*3e777be0SXin Li const unsigned int numberOfThreads,
158*3e777be0SXin Li const bool importEnabled,
159*3e777be0SXin Li const bool exportEnabled)
160*3e777be0SXin Li : m_NetworkId(networkId)
161*3e777be0SXin Li , m_Runtime(runtime)
162*3e777be0SXin Li , m_Model(model)
163*3e777be0SXin Li , m_RequestCount(0)
164*3e777be0SXin Li , m_RequestInputsAndOutputsDumpDir(requestInputsAndOutputsDumpDir)
165*3e777be0SXin Li , m_GpuProfilingEnabled(gpuProfilingEnabled)
166*3e777be0SXin Li , m_AsyncModelExecutionEnabled(asyncModelExecutionEnabled)
167*3e777be0SXin Li , m_EnableImport(importEnabled)
168*3e777be0SXin Li , m_EnableExport(exportEnabled)
169*3e777be0SXin Li , m_PreparedFromCache(false)
170*3e777be0SXin Li {
171*3e777be0SXin Li // Enable profiling if required.
172*3e777be0SXin Li m_Runtime->GetProfiler(m_NetworkId)->EnableProfiling(m_GpuProfilingEnabled);
173*3e777be0SXin Li
174*3e777be0SXin Li if (m_AsyncModelExecutionEnabled)
175*3e777be0SXin Li {
176*3e777be0SXin Li std::vector<std::shared_ptr<armnn::IWorkingMemHandle>> memHandles;
177*3e777be0SXin Li for (unsigned int i=0; i < numberOfThreads; ++i)
178*3e777be0SXin Li {
179*3e777be0SXin Li memHandles.emplace_back(m_Runtime->CreateWorkingMemHandle(networkId));
180*3e777be0SXin Li }
181*3e777be0SXin Li
182*3e777be0SXin Li if (!m_Threadpool)
183*3e777be0SXin Li {
184*3e777be0SXin Li m_Threadpool = std::make_unique<armnn::Threadpool>(numberOfThreads, runtime, memHandles);
185*3e777be0SXin Li }
186*3e777be0SXin Li else
187*3e777be0SXin Li {
188*3e777be0SXin Li m_Threadpool->LoadMemHandles(memHandles);
189*3e777be0SXin Li }
190*3e777be0SXin Li
191*3e777be0SXin Li m_WorkingMemHandle = memHandles.back();
192*3e777be0SXin Li }
193*3e777be0SXin Li }
194*3e777be0SXin Li
195*3e777be0SXin Li template<typename HalVersion>
ArmnnPreparedModel_1_2(armnn::NetworkId networkId,armnn::IRuntime * runtime,const std::string & requestInputsAndOutputsDumpDir,const bool gpuProfilingEnabled,const bool asyncModelExecutionEnabled,const unsigned int numberOfThreads,const bool importEnabled,const bool exportEnabled,const bool preparedFromCache)196*3e777be0SXin Li ArmnnPreparedModel_1_2<HalVersion>::ArmnnPreparedModel_1_2(armnn::NetworkId networkId,
197*3e777be0SXin Li armnn::IRuntime* runtime,
198*3e777be0SXin Li const std::string& requestInputsAndOutputsDumpDir,
199*3e777be0SXin Li const bool gpuProfilingEnabled,
200*3e777be0SXin Li const bool asyncModelExecutionEnabled,
201*3e777be0SXin Li const unsigned int numberOfThreads,
202*3e777be0SXin Li const bool importEnabled,
203*3e777be0SXin Li const bool exportEnabled,
204*3e777be0SXin Li const bool preparedFromCache)
205*3e777be0SXin Li : m_NetworkId(networkId)
206*3e777be0SXin Li , m_Runtime(runtime)
207*3e777be0SXin Li , m_RequestCount(0)
208*3e777be0SXin Li , m_RequestInputsAndOutputsDumpDir(requestInputsAndOutputsDumpDir)
209*3e777be0SXin Li , m_GpuProfilingEnabled(gpuProfilingEnabled)
210*3e777be0SXin Li , m_AsyncModelExecutionEnabled(asyncModelExecutionEnabled)
211*3e777be0SXin Li , m_EnableImport(importEnabled)
212*3e777be0SXin Li , m_EnableExport(exportEnabled)
213*3e777be0SXin Li , m_PreparedFromCache(preparedFromCache)
214*3e777be0SXin Li {
215*3e777be0SXin Li // Enable profiling if required.
216*3e777be0SXin Li m_Runtime->GetProfiler(m_NetworkId)->EnableProfiling(m_GpuProfilingEnabled);
217*3e777be0SXin Li
218*3e777be0SXin Li if (m_AsyncModelExecutionEnabled)
219*3e777be0SXin Li {
220*3e777be0SXin Li std::vector<std::shared_ptr<armnn::IWorkingMemHandle>> memHandles;
221*3e777be0SXin Li for (unsigned int i=0; i < numberOfThreads; ++i)
222*3e777be0SXin Li {
223*3e777be0SXin Li memHandles.emplace_back(m_Runtime->CreateWorkingMemHandle(networkId));
224*3e777be0SXin Li }
225*3e777be0SXin Li
226*3e777be0SXin Li if (!m_Threadpool)
227*3e777be0SXin Li {
228*3e777be0SXin Li m_Threadpool = std::make_unique<armnn::Threadpool>(numberOfThreads, runtime, memHandles);
229*3e777be0SXin Li }
230*3e777be0SXin Li else
231*3e777be0SXin Li {
232*3e777be0SXin Li m_Threadpool->LoadMemHandles(memHandles);
233*3e777be0SXin Li }
234*3e777be0SXin Li
235*3e777be0SXin Li m_WorkingMemHandle = memHandles.back();
236*3e777be0SXin Li }
237*3e777be0SXin Li }
238*3e777be0SXin Li
239*3e777be0SXin Li template<typename HalVersion>
~ArmnnPreparedModel_1_2()240*3e777be0SXin Li ArmnnPreparedModel_1_2<HalVersion>::~ArmnnPreparedModel_1_2()
241*3e777be0SXin Li {
242*3e777be0SXin Li // Get a hold of the profiler used by this model.
243*3e777be0SXin Li std::shared_ptr<armnn::IProfiler> profiler = m_Runtime->GetProfiler(m_NetworkId);
244*3e777be0SXin Li if (profiler && m_GpuProfilingEnabled)
245*3e777be0SXin Li {
246*3e777be0SXin Li // Dump the profiling info to a file if required.
247*3e777be0SXin Li DumpJsonProfilingIfRequired(m_GpuProfilingEnabled, m_RequestInputsAndOutputsDumpDir, m_NetworkId,
248*3e777be0SXin Li profiler.get());
249*3e777be0SXin Li }
250*3e777be0SXin Li
251*3e777be0SXin Li // Unload the network associated with this model.
252*3e777be0SXin Li m_Runtime->UnloadNetwork(m_NetworkId);
253*3e777be0SXin Li
254*3e777be0SXin Li // Unload the network memhandles from the threadpool
255*3e777be0SXin Li if (m_AsyncModelExecutionEnabled)
256*3e777be0SXin Li {
257*3e777be0SXin Li m_Threadpool->UnloadMemHandles(m_NetworkId);
258*3e777be0SXin Li }
259*3e777be0SXin Li }
260*3e777be0SXin Li
261*3e777be0SXin Li template<typename HalVersion>
execute(const V1_0::Request & request,const::android::sp<V1_0::IExecutionCallback> & callback)262*3e777be0SXin Li Return <V1_0::ErrorStatus> ArmnnPreparedModel_1_2<HalVersion>::execute(const V1_0::Request& request,
263*3e777be0SXin Li const ::android::sp<V1_0::IExecutionCallback>& callback)
264*3e777be0SXin Li {
265*3e777be0SXin Li if (callback.get() == nullptr)
266*3e777be0SXin Li {
267*3e777be0SXin Li ALOGE("ArmnnPreparedModel_1_2::execute invalid callback passed");
268*3e777be0SXin Li return V1_0::ErrorStatus::INVALID_ARGUMENT;
269*3e777be0SXin Li }
270*3e777be0SXin Li
271*3e777be0SXin Li auto cb = [callback](V1_0::ErrorStatus errorStatus,
272*3e777be0SXin Li std::vector<V1_2::OutputShape> outputShapes,
273*3e777be0SXin Li const V1_2::Timing& timing,
274*3e777be0SXin Li std::string callingFunction)
275*3e777be0SXin Li {
276*3e777be0SXin Li NotifyCallbackAndCheck(callback, errorStatus, outputShapes, timing, callingFunction);
277*3e777be0SXin Li };
278*3e777be0SXin Li
279*3e777be0SXin Li return Execute(request, V1_2::MeasureTiming::NO, cb);
280*3e777be0SXin Li }
281*3e777be0SXin Li
282*3e777be0SXin Li template<typename HalVersion>
execute_1_2(const V1_0::Request & request,V1_2::MeasureTiming measureTiming,const sp<V1_2::IExecutionCallback> & callback)283*3e777be0SXin Li Return <V1_0::ErrorStatus> ArmnnPreparedModel_1_2<HalVersion>::execute_1_2(
284*3e777be0SXin Li const V1_0::Request& request,
285*3e777be0SXin Li V1_2::MeasureTiming measureTiming,
286*3e777be0SXin Li const sp<V1_2::IExecutionCallback>& callback)
287*3e777be0SXin Li {
288*3e777be0SXin Li if (callback.get() == nullptr)
289*3e777be0SXin Li {
290*3e777be0SXin Li ALOGE("ArmnnPreparedModel_1_2::execute_1_2 invalid callback passed");
291*3e777be0SXin Li return V1_0::ErrorStatus::INVALID_ARGUMENT;
292*3e777be0SXin Li }
293*3e777be0SXin Li
294*3e777be0SXin Li auto cb = [callback](V1_0::ErrorStatus errorStatus,
295*3e777be0SXin Li std::vector<V1_2::OutputShape> outputShapes,
296*3e777be0SXin Li const V1_2::Timing& timing,
297*3e777be0SXin Li std::string callingFunction)
298*3e777be0SXin Li {
299*3e777be0SXin Li NotifyCallbackAndCheck(callback, errorStatus, outputShapes, timing, callingFunction);
300*3e777be0SXin Li };
301*3e777be0SXin Li
302*3e777be0SXin Li return Execute(request, measureTiming, cb);
303*3e777be0SXin Li }
304*3e777be0SXin Li
305*3e777be0SXin Li template<typename HalVersion>
PrepareMemoryForInputs(armnn::InputTensors & inputs,const V1_0::Request & request,const std::vector<android::nn::RunTimePoolInfo> & memPools)306*3e777be0SXin Li Return<V1_0::ErrorStatus> ArmnnPreparedModel_1_2<HalVersion>::PrepareMemoryForInputs(
307*3e777be0SXin Li armnn::InputTensors& inputs,
308*3e777be0SXin Li const V1_0::Request& request,
309*3e777be0SXin Li const std::vector<android::nn::RunTimePoolInfo>& memPools)
310*3e777be0SXin Li {
311*3e777be0SXin Li inputs.reserve(request.inputs.size());
312*3e777be0SXin Li for (unsigned int i = 0; i < request.inputs.size(); i++)
313*3e777be0SXin Li {
314*3e777be0SXin Li const auto& inputArg = request.inputs[i];
315*3e777be0SXin Li armnn::TensorInfo inputTensorInfo = m_Runtime->GetInputTensorInfo(m_NetworkId, i);
316*3e777be0SXin Li // inputs (of type InputTensors) is composed of a vector of ConstTensors.
317*3e777be0SXin Li // Therefore, set all TensorInfo isConstant parameters of input Tensors to true.
318*3e777be0SXin Li inputTensorInfo.SetConstant();
319*3e777be0SXin Li auto result = ValidateRequestArgument<V1_0::ErrorStatus, V1_0::Request>(request,
320*3e777be0SXin Li inputTensorInfo,
321*3e777be0SXin Li inputArg,
322*3e777be0SXin Li "input");
323*3e777be0SXin Li
324*3e777be0SXin Li if (result != V1_0::ErrorStatus::NONE)
325*3e777be0SXin Li {
326*3e777be0SXin Li return result;
327*3e777be0SXin Li }
328*3e777be0SXin Li
329*3e777be0SXin Li const armnn::Tensor inputTensor = GetTensorForRequestArgument(inputArg, inputTensorInfo, memPools);
330*3e777be0SXin Li
331*3e777be0SXin Li if (inputTensor.GetMemoryArea() == nullptr)
332*3e777be0SXin Li {
333*3e777be0SXin Li ALOGE("Cannot execute request. Error converting request input %u to tensor", i);
334*3e777be0SXin Li return V1_0::ErrorStatus::GENERAL_FAILURE;
335*3e777be0SXin Li }
336*3e777be0SXin Li
337*3e777be0SXin Li inputs.emplace_back(i, inputTensor);
338*3e777be0SXin Li }
339*3e777be0SXin Li
340*3e777be0SXin Li return V1_0::ErrorStatus::NONE;
341*3e777be0SXin Li }
342*3e777be0SXin Li
343*3e777be0SXin Li template<typename HalVersion>
PrepareMemoryForOutputs(armnn::OutputTensors & outputs,std::vector<V1_2::OutputShape> & outputShapes,const V1_0::Request & request,const std::vector<android::nn::RunTimePoolInfo> & memPools)344*3e777be0SXin Li Return<V1_0::ErrorStatus> ArmnnPreparedModel_1_2<HalVersion>::PrepareMemoryForOutputs(
345*3e777be0SXin Li armnn::OutputTensors& outputs,
346*3e777be0SXin Li std::vector<V1_2::OutputShape> &outputShapes,
347*3e777be0SXin Li const V1_0::Request& request,
348*3e777be0SXin Li const std::vector<android::nn::RunTimePoolInfo>& memPools)
349*3e777be0SXin Li {
350*3e777be0SXin Li outputs.reserve(request.outputs.size());
351*3e777be0SXin Li for (unsigned int i = 0; i < request.outputs.size(); i++)
352*3e777be0SXin Li {
353*3e777be0SXin Li const auto& outputArg = request.outputs[i];
354*3e777be0SXin Li armnn::TensorInfo outputTensorInfo = m_Runtime->GetOutputTensorInfo(m_NetworkId, i);
355*3e777be0SXin Li auto result = ValidateRequestArgument<V1_0::ErrorStatus, V1_0::Request>(request,
356*3e777be0SXin Li outputTensorInfo,
357*3e777be0SXin Li outputArg,
358*3e777be0SXin Li "output");
359*3e777be0SXin Li
360*3e777be0SXin Li if (result != V1_0::ErrorStatus::NONE)
361*3e777be0SXin Li {
362*3e777be0SXin Li return result;
363*3e777be0SXin Li }
364*3e777be0SXin Li
365*3e777be0SXin Li const armnn::Tensor outputTensor = GetTensorForRequestArgument(outputArg, outputTensorInfo, memPools);
366*3e777be0SXin Li if (outputTensor.GetMemoryArea() == nullptr)
367*3e777be0SXin Li {
368*3e777be0SXin Li ALOGE("Cannot execute request. Error converting request output %u to tensor", i);
369*3e777be0SXin Li return V1_0::ErrorStatus::GENERAL_FAILURE;
370*3e777be0SXin Li }
371*3e777be0SXin Li
372*3e777be0SXin Li const size_t outputSize = outputTensorInfo.GetNumBytes();
373*3e777be0SXin Li
374*3e777be0SXin Li if (outputArg.location.length < outputSize)
375*3e777be0SXin Li {
376*3e777be0SXin Li ALOGW("ArmnnPreparedModel_1_2::Execute failed: outputArg.location.length < outputSize");
377*3e777be0SXin Li return V1_0::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE;
378*3e777be0SXin Li }
379*3e777be0SXin Li
380*3e777be0SXin Li #if !defined(ARMNN_ANDROID_S)
381*3e777be0SXin Li const size_t bufferSize = memPools.at(outputArg.location.poolIndex).getHidlMemory().size();
382*3e777be0SXin Li if (bufferSize < outputSize)
383*3e777be0SXin Li {
384*3e777be0SXin Li ALOGW("ArmnnPreparedModel_1_2::Execute failed: bufferSize < outputSize");
385*3e777be0SXin Li return V1_0::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE;
386*3e777be0SXin Li }
387*3e777be0SXin Li #else
388*3e777be0SXin Li const size_t bufferSize = memPools.at(outputArg.location.poolIndex).getSize();
389*3e777be0SXin Li if (bufferSize < outputSize)
390*3e777be0SXin Li {
391*3e777be0SXin Li ALOGW("ArmnnPreparedModel_1_2::Execute failed bufferSize (%s) < outputSize (%s)",
392*3e777be0SXin Li std::to_string(bufferSize).c_str(), std::to_string(outputSize).c_str());
393*3e777be0SXin Li outputShapes[i].isSufficient = false;
394*3e777be0SXin Li return V1_0::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE;
395*3e777be0SXin Li }
396*3e777be0SXin Li #endif
397*3e777be0SXin Li outputs.emplace_back(i, outputTensor);
398*3e777be0SXin Li outputShapes[i] = ComputeShape(outputTensorInfo);
399*3e777be0SXin Li }
400*3e777be0SXin Li
401*3e777be0SXin Li return V1_0::ErrorStatus::NONE;
402*3e777be0SXin Li }
403*3e777be0SXin Li
// Maps the request's memory pools and converts its input/output arguments
// into armnn tensors backed by that memory.
//
// On any failure the supplied callback is invoked (with the failure status
// and no timing) before the error is returned, so callers can simply bail
// out when the result is not NONE.
//
// @param inputs   Out: populated with one ConstTensor per request input.
// @param outputs  Out: populated with one Tensor per request output.
// @param memPools Out: populated with the mapped runtime pools for 'request'.
// @param request  The HIDL request being prepared.
// @param callback Invoked on failure to report the error to the client.
// @return NONE on success; the specific failure status otherwise.
template<typename HalVersion>
Return<V1_0::ErrorStatus> ArmnnPreparedModel_1_2<HalVersion>::PrepareMemoryForIO(
    armnn::InputTensors& inputs,
    armnn::OutputTensors& outputs,
    std::vector<android::nn::RunTimePoolInfo>& memPools,
    const V1_0::Request& request,
    CallbackAsync_1_2 callback)
{
    // The pool-mapping helper differs between Android S and earlier releases;
    // both preprocessor branches supply the condition of the same 'if'.
#if !defined(ARMNN_ANDROID_S)
    if (!setRunTimePoolInfosFromHidlMemories(&memPools, request.pools))
#else
    if (!setRunTimePoolInfosFromCanonicalMemories(&memPools, uncheckedConvert(request.pools)))
#endif
    {
        callback(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_2::execute");
        return V1_0::ErrorStatus::GENERAL_FAILURE;
    }
    // add the inputs and outputs with their data
    try
    {
        if (PrepareMemoryForInputs(inputs, request, memPools) != V1_0::ErrorStatus::NONE)
        {
            callback(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_2::execute");
            return V1_0::ErrorStatus::GENERAL_FAILURE;
        }

        // Pre-size the shapes vector so PrepareMemoryForOutputs can record a
        // shape (or an insufficient-size flag) per output.
        std::vector<V1_2::OutputShape> outputShapes(request.outputs.size());

        auto errorStatus = PrepareMemoryForOutputs(outputs, outputShapes, request, memPools);
        if (errorStatus != V1_0::ErrorStatus::NONE)
        {
            // Unlike the input path, the output shapes gathered so far are
            // forwarded to the client (needed for OUTPUT_INSUFFICIENT_SIZE).
            callback(errorStatus,
                     outputShapes,
                     g_NoTiming,
                     "ArmnnPreparedModel_1_2::Execute");
            return errorStatus;
        }
    }
    catch (armnn::Exception& e)
    {
        ALOGW("armnn::Exception caught while preparing for EnqueueWorkload: %s", e.what());
        callback(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_2::execute");
        return V1_0::ErrorStatus::GENERAL_FAILURE;
    }
    catch (std::exception& e)
    {
        ALOGE("std::exception caught while preparing for EnqueueWorkload: %s", e.what());
        callback(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_2::execute");
        return V1_0::ErrorStatus::GENERAL_FAILURE;
    }

    return V1_0::ErrorStatus::NONE;
}
457*3e777be0SXin Li
458*3e777be0SXin Li template<typename HalVersion>
executeSynchronously(const V1_0::Request & request,V1_2::MeasureTiming measureTiming,executeSynchronously_cb cb)459*3e777be0SXin Li Return<void> ArmnnPreparedModel_1_2<HalVersion>::executeSynchronously(const V1_0::Request& request,
460*3e777be0SXin Li V1_2::MeasureTiming measureTiming,
461*3e777be0SXin Li executeSynchronously_cb cb)
462*3e777be0SXin Li {
463*3e777be0SXin Li if (!m_PreparedFromCache)
464*3e777be0SXin Li {
465*3e777be0SXin Li ALOGV("ArmnnPreparedModel_1_2::executeSynchronously(): %s", GetModelSummary(m_Model).c_str());
466*3e777be0SXin Li }
467*3e777be0SXin Li m_RequestCount++;
468*3e777be0SXin Li
469*3e777be0SXin Li if (cb == nullptr)
470*3e777be0SXin Li {
471*3e777be0SXin Li ALOGE("ArmnnPreparedModel_1_2::executeSynchronously invalid callback passed");
472*3e777be0SXin Li return Void();
473*3e777be0SXin Li }
474*3e777be0SXin Li
475*3e777be0SXin Li TimePoint driverStart;
476*3e777be0SXin Li
477*3e777be0SXin Li if (measureTiming == V1_2::MeasureTiming::YES)
478*3e777be0SXin Li {
479*3e777be0SXin Li driverStart = Now();
480*3e777be0SXin Li }
481*3e777be0SXin Li
482*3e777be0SXin Li if (!m_PreparedFromCache && !android::nn::validateRequest(request, m_Model))
483*3e777be0SXin Li {
484*3e777be0SXin Li ALOGE("ArmnnPreparedModel_1_2::executeSynchronously invalid request model");
485*3e777be0SXin Li cb(V1_0::ErrorStatus::INVALID_ARGUMENT, {}, g_NoTiming);
486*3e777be0SXin Li return Void();
487*3e777be0SXin Li }
488*3e777be0SXin Li
489*3e777be0SXin Li auto cbWrapper = [cb](V1_0::ErrorStatus errorStatus,
490*3e777be0SXin Li std::vector<V1_2::OutputShape> outputShapes,
491*3e777be0SXin Li const V1_2::Timing& timing,
492*3e777be0SXin Li std::string)
493*3e777be0SXin Li {
494*3e777be0SXin Li cb(errorStatus, outputShapes, timing);
495*3e777be0SXin Li };
496*3e777be0SXin Li
497*3e777be0SXin Li // map the memory pool into shared pointers
498*3e777be0SXin Li // use a shared memory pools vector on the heap, as it is passed to the request thread
499*3e777be0SXin Li auto memPools = std::make_shared<std::vector<android::nn::RunTimePoolInfo>>();
500*3e777be0SXin Li
501*3e777be0SXin Li // allocate the tensors on the heap, as they are passed to the request thread
502*3e777be0SXin Li auto inputs = std::make_shared<armnn::InputTensors>();
503*3e777be0SXin Li auto outputs = std::make_shared<armnn::OutputTensors>();
504*3e777be0SXin Li
505*3e777be0SXin Li auto prepareStatus = PrepareMemoryForIO(*inputs, *outputs, *memPools, request, cbWrapper);
506*3e777be0SXin Li if (prepareStatus != V1_0::ErrorStatus::NONE)
507*3e777be0SXin Li {
508*3e777be0SXin Li return Void();
509*3e777be0SXin Li }
510*3e777be0SXin Li
511*3e777be0SXin Li ALOGV("ArmnnPreparedModel_1_2::executeSynchronously() before Execution");
512*3e777be0SXin Li
513*3e777be0SXin Li CallbackContext_1_2 cbCtx;
514*3e777be0SXin Li cbCtx.callback = cbWrapper;
515*3e777be0SXin Li cbCtx.ctx.measureTimings = measureTiming;
516*3e777be0SXin Li cbCtx.ctx.driverStart = driverStart;
517*3e777be0SXin Li ExecuteGraph(memPools, *inputs, *outputs, cbCtx);
518*3e777be0SXin Li
519*3e777be0SXin Li return Void();
520*3e777be0SXin Li }
521*3e777be0SXin Li
522*3e777be0SXin Li template<typename HalVersion>
523*3e777be0SXin Li template<typename CallbackContext>
ExecuteGraph(std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>> & pMemPools,armnn::InputTensors & inputTensors,armnn::OutputTensors & outputTensors,CallbackContext cb)524*3e777be0SXin Li bool ArmnnPreparedModel_1_2<HalVersion>::ExecuteGraph(
525*3e777be0SXin Li std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
526*3e777be0SXin Li armnn::InputTensors& inputTensors,
527*3e777be0SXin Li armnn::OutputTensors& outputTensors,
528*3e777be0SXin Li CallbackContext cb)
529*3e777be0SXin Li {
530*3e777be0SXin Li ALOGV("ArmnnPreparedModel_1_2::ExecuteGraph(...)");
531*3e777be0SXin Li
532*3e777be0SXin Li TimePoint driverEnd, deviceStart, deviceEnd;
533*3e777be0SXin Li // Capture the graph execution start time.
534*3e777be0SXin Li std::chrono::time_point<std::chrono::system_clock> graphExecutionStart = std::chrono::system_clock::now();
535*3e777be0SXin Li
536*3e777be0SXin Li DumpTensorsIfRequired("Input", inputTensors);
537*3e777be0SXin Li
538*3e777be0SXin Li std::vector<V1_2::OutputShape> outputShapes(outputTensors.size());
539*3e777be0SXin Li for (unsigned int i = 0; i < outputTensors.size(); i++)
540*3e777be0SXin Li {
541*3e777be0SXin Li std::pair<int, armnn::Tensor> outputTensorPair = outputTensors[i];
542*3e777be0SXin Li const armnn::Tensor outputTensor = outputTensorPair.second;
543*3e777be0SXin Li const armnn::TensorInfo outputTensorInfo = outputTensor.GetInfo();
544*3e777be0SXin Li
545*3e777be0SXin Li outputShapes[i] = ComputeShape(outputTensorInfo);
546*3e777be0SXin Li }
547*3e777be0SXin Li
548*3e777be0SXin Li // run it
549*3e777be0SXin Li try
550*3e777be0SXin Li {
551*3e777be0SXin Li if (cb.ctx.measureTimings == V1_2::MeasureTiming::YES)
552*3e777be0SXin Li {
553*3e777be0SXin Li deviceStart = Now();
554*3e777be0SXin Li }
555*3e777be0SXin Li
556*3e777be0SXin Li armnn::Status status;
557*3e777be0SXin Li if (m_AsyncModelExecutionEnabled)
558*3e777be0SXin Li {
559*3e777be0SXin Li ALOGW("ArmnnPreparedModel_1_2::ExecuteGraph m_AsyncModelExecutionEnabled true");
560*3e777be0SXin Li status = m_Runtime->Execute(*m_WorkingMemHandle, inputTensors, outputTensors);
561*3e777be0SXin Li }
562*3e777be0SXin Li else
563*3e777be0SXin Li {
564*3e777be0SXin Li ALOGW("ArmnnPreparedModel_1_2::ExecuteGraph m_AsyncModelExecutionEnabled false");
565*3e777be0SXin Li
566*3e777be0SXin Li // Create a vector of Input and Output Ids which can be imported. An empty vector means all will be copied.
567*3e777be0SXin Li std::vector<armnn::ImportedInputId> importedInputIds;
568*3e777be0SXin Li if (m_EnableImport)
569*3e777be0SXin Li {
570*3e777be0SXin Li importedInputIds = m_Runtime->ImportInputs(m_NetworkId, inputTensors, armnn::MemorySource::Malloc);
571*3e777be0SXin Li }
572*3e777be0SXin Li std::vector<armnn::ImportedOutputId> importedOutputIds;
573*3e777be0SXin Li if (m_EnableExport)
574*3e777be0SXin Li {
575*3e777be0SXin Li importedOutputIds = m_Runtime->ImportOutputs(m_NetworkId, outputTensors, armnn::MemorySource::Malloc);
576*3e777be0SXin Li }
577*3e777be0SXin Li status = m_Runtime->EnqueueWorkload(m_NetworkId, inputTensors, outputTensors,
578*3e777be0SXin Li importedInputIds, importedOutputIds);
579*3e777be0SXin Li }
580*3e777be0SXin Li
581*3e777be0SXin Li if (cb.ctx.measureTimings == V1_2::MeasureTiming::YES)
582*3e777be0SXin Li {
583*3e777be0SXin Li deviceEnd = Now();
584*3e777be0SXin Li }
585*3e777be0SXin Li if (status != armnn::Status::Success)
586*3e777be0SXin Li {
587*3e777be0SXin Li ALOGW("EnqueueWorkload failed");
588*3e777be0SXin Li cb.callback(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming,
589*3e777be0SXin Li "ArmnnPreparedModel_1_2::ExecuteGraph");
590*3e777be0SXin Li return false;
591*3e777be0SXin Li }
592*3e777be0SXin Li }
593*3e777be0SXin Li catch (armnn::Exception& e)
594*3e777be0SXin Li {
595*3e777be0SXin Li ALOGW("armnn:Exception caught from EnqueueWorkload: %s", e.what());
596*3e777be0SXin Li cb.callback(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_2::ExecuteGraph");
597*3e777be0SXin Li return false;
598*3e777be0SXin Li }
599*3e777be0SXin Li catch (std::exception& e)
600*3e777be0SXin Li {
601*3e777be0SXin Li ALOGE("std::exception caught from EnqueueWorkload: %s", e.what());
602*3e777be0SXin Li cb.callback(V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel_1_2::ExecuteGraph");
603*3e777be0SXin Li return false;
604*3e777be0SXin Li }
605*3e777be0SXin Li
606*3e777be0SXin Li CommitPools(*pMemPools);
607*3e777be0SXin Li
608*3e777be0SXin Li DumpTensorsIfRequired("Output", outputTensors);
609*3e777be0SXin Li
610*3e777be0SXin Li if (cb.ctx.measureTimings == V1_2::MeasureTiming::YES)
611*3e777be0SXin Li {
612*3e777be0SXin Li driverEnd = Now();
613*3e777be0SXin Li V1_2::Timing timing;
614*3e777be0SXin Li timing.timeOnDevice = MicrosecondsDuration(deviceEnd, deviceStart);
615*3e777be0SXin Li timing.timeInDriver = MicrosecondsDuration(driverEnd, cb.ctx.driverStart);
616*3e777be0SXin Li ALOGV("ArmnnPreparedModel_1_2::execute timing - Device = %lu Driver = %lu",
617*3e777be0SXin Li static_cast<unsigned long>(timing.timeOnDevice), static_cast<unsigned long>(timing.timeInDriver));
618*3e777be0SXin Li cb.callback(V1_0::ErrorStatus::NONE, outputShapes, timing, "ArmnnPreparedModel_1_2::ExecuteGraph");
619*3e777be0SXin Li } else {
620*3e777be0SXin Li cb.callback(V1_0::ErrorStatus::NONE, outputShapes, g_NoTiming, "ArmnnPreparedModel_1_2::ExecuteGraph");
621*3e777be0SXin Li }
622*3e777be0SXin Li
623*3e777be0SXin Li // Log the total time in this call. This is a good number to compare to that printed out by
624*3e777be0SXin Li // RuntimeImpl::EnqueueWorkload. The difference should be the execution overhead of the driver.
625*3e777be0SXin Li ALOGI("ArmnnPreparedModel_1_2::ExecuteGraph Execution time = %lld µs",
626*3e777be0SXin Li std::chrono::duration_cast<std::chrono::microseconds>
627*3e777be0SXin Li (std::chrono::system_clock::now() - graphExecutionStart).count());
628*3e777be0SXin Li return true;
629*3e777be0SXin Li }
630*3e777be0SXin Li
631*3e777be0SXin Li template<typename HalVersion>
ExecuteWithDummyInputs(unsigned int numInputs,unsigned int numOutputs)632*3e777be0SXin Li bool ArmnnPreparedModel_1_2<HalVersion>::ExecuteWithDummyInputs(unsigned int numInputs, unsigned int numOutputs)
633*3e777be0SXin Li {
634*3e777be0SXin Li std::vector<std::vector<char>> storage;
635*3e777be0SXin Li armnn::InputTensors inputTensors;
636*3e777be0SXin Li for (unsigned int i = 0; i < numInputs; i++)
637*3e777be0SXin Li {
638*3e777be0SXin Li armnn::TensorInfo inputTensorInfo = m_Runtime->GetInputTensorInfo(m_NetworkId, i);
639*3e777be0SXin Li // pInputTensors (of type InputTensors) is composed of a vector of ConstTensors.
640*3e777be0SXin Li // Therefore, set all TensorInfo isConstant parameters of input Tensors to true.
641*3e777be0SXin Li inputTensorInfo.SetConstant();
642*3e777be0SXin Li
643*3e777be0SXin Li storage.emplace_back(inputTensorInfo.GetNumBytes());
644*3e777be0SXin Li const armnn::ConstTensor inputTensor(inputTensorInfo, storage.back().data());
645*3e777be0SXin Li
646*3e777be0SXin Li inputTensors.emplace_back(i, inputTensor);
647*3e777be0SXin Li }
648*3e777be0SXin Li
649*3e777be0SXin Li armnn::OutputTensors outputTensors;
650*3e777be0SXin Li for (unsigned int i = 0; i < numOutputs; i++)
651*3e777be0SXin Li {
652*3e777be0SXin Li const armnn::TensorInfo outputTensorInfo = m_Runtime->GetOutputTensorInfo(m_NetworkId, i);
653*3e777be0SXin Li storage.emplace_back(outputTensorInfo.GetNumBytes());
654*3e777be0SXin Li const armnn::Tensor outputTensor(outputTensorInfo, storage.back().data());
655*3e777be0SXin Li
656*3e777be0SXin Li outputTensors.emplace_back(i, outputTensor);
657*3e777be0SXin Li }
658*3e777be0SXin Li
659*3e777be0SXin Li auto nullCallback = [](V1_0::ErrorStatus, std::vector<V1_2::OutputShape>, const V1_2::Timing&, std::string) {};
660*3e777be0SXin Li CallbackContext_1_2 callbackContext;
661*3e777be0SXin Li callbackContext.callback = nullCallback;
662*3e777be0SXin Li callbackContext.ctx.measureTimings = V1_2::MeasureTiming::NO;
663*3e777be0SXin Li auto memPools = std::make_shared<std::vector<::android::nn::RunTimePoolInfo>>();
664*3e777be0SXin Li return ExecuteGraph(memPools,
665*3e777be0SXin Li inputTensors,
666*3e777be0SXin Li outputTensors,
667*3e777be0SXin Li callbackContext);
668*3e777be0SXin Li }
669*3e777be0SXin Li
670*3e777be0SXin Li template<typename HalVersion>
Execute(const V1_0::Request & request,V1_2::MeasureTiming measureTiming,CallbackAsync_1_2 callback)671*3e777be0SXin Li Return <V1_0::ErrorStatus> ArmnnPreparedModel_1_2<HalVersion>::Execute(const V1_0::Request& request,
672*3e777be0SXin Li V1_2::MeasureTiming measureTiming,
673*3e777be0SXin Li CallbackAsync_1_2 callback)
674*3e777be0SXin Li {
675*3e777be0SXin Li ExecutionContext_1_2 ctx;
676*3e777be0SXin Li if (measureTiming == V1_2::MeasureTiming::YES)
677*3e777be0SXin Li {
678*3e777be0SXin Li ctx.measureTimings = measureTiming;
679*3e777be0SXin Li ctx.driverStart = Now();
680*3e777be0SXin Li }
681*3e777be0SXin Li
682*3e777be0SXin Li if (!m_PreparedFromCache)
683*3e777be0SXin Li {
684*3e777be0SXin Li ALOGV("ArmnnPreparedModel_1_2::execute(): %s", GetModelSummary(m_Model).c_str());
685*3e777be0SXin Li }
686*3e777be0SXin Li m_RequestCount++;
687*3e777be0SXin Li
688*3e777be0SXin Li if (!m_PreparedFromCache && !android::nn::validateRequest(request, m_Model))
689*3e777be0SXin Li {
690*3e777be0SXin Li callback(V1_0::ErrorStatus::INVALID_ARGUMENT, {}, g_NoTiming, "ArmnnPreparedModel_1_2::execute");
691*3e777be0SXin Li return V1_0::ErrorStatus::INVALID_ARGUMENT;
692*3e777be0SXin Li }
693*3e777be0SXin Li
694*3e777be0SXin Li if (!m_RequestInputsAndOutputsDumpDir.empty())
695*3e777be0SXin Li {
696*3e777be0SXin Li ALOGD("Dumping inputs and outputs for request %" PRIuPTR, reinterpret_cast<std::uintptr_t>(&callback));
697*3e777be0SXin Li }
698*3e777be0SXin Li
699*3e777be0SXin Li // map the memory pool into shared pointers
700*3e777be0SXin Li // use a shared memory pools vector on the heap, as it is passed to the request thread
701*3e777be0SXin Li auto memPools = std::make_shared<std::vector<android::nn::RunTimePoolInfo>>();
702*3e777be0SXin Li
703*3e777be0SXin Li // allocate the tensors on the heap, as they are passed to the request thread
704*3e777be0SXin Li auto inputTensors = std::make_shared<armnn::InputTensors>();
705*3e777be0SXin Li auto outputTensors = std::make_shared<armnn::OutputTensors>();
706*3e777be0SXin Li
707*3e777be0SXin Li auto prepareStatus = PrepareMemoryForIO(*inputTensors, *outputTensors, *memPools, request, callback);
708*3e777be0SXin Li switch(prepareStatus)
709*3e777be0SXin Li {
710*3e777be0SXin Li case V1_0::ErrorStatus::OUTPUT_INSUFFICIENT_SIZE:
711*3e777be0SXin Li return V1_0::ErrorStatus::NONE;
712*3e777be0SXin Li case V1_0::ErrorStatus::GENERAL_FAILURE:
713*3e777be0SXin Li return V1_0::ErrorStatus::GENERAL_FAILURE;
714*3e777be0SXin Li default:
715*3e777be0SXin Li {}
716*3e777be0SXin Li }
717*3e777be0SXin Li
718*3e777be0SXin Li
719*3e777be0SXin Li // post the request for asynchronous execution
720*3e777be0SXin Li CallbackContext_1_2 cb;
721*3e777be0SXin Li cb.callback = callback;
722*3e777be0SXin Li cb.ctx = ctx;
723*3e777be0SXin Li
724*3e777be0SXin Li if (m_AsyncModelExecutionEnabled)
725*3e777be0SXin Li {
726*3e777be0SXin Li ALOGV("ArmnnPreparedModel_1_2::execute(...) before ScheduleGraphForExecution");
727*3e777be0SXin Li ScheduleGraphForExecution(memPools, inputTensors, outputTensors, cb);
728*3e777be0SXin Li ALOGV("ArmnnPreparedModel_1_2::execute(...) after ScheduleGraphForExecution");
729*3e777be0SXin Li return V1_0::ErrorStatus::NONE;
730*3e777be0SXin Li }
731*3e777be0SXin Li
732*3e777be0SXin Li ALOGV("ArmnnPreparedModel_1_2::execute(...) before PostMsg");
733*3e777be0SXin Li m_RequestThread.PostMsg(this, memPools, inputTensors, outputTensors, cb);
734*3e777be0SXin Li ALOGV("ArmnnPreparedModel_1_2::execute(...) after PostMsg");
735*3e777be0SXin Li return V1_0::ErrorStatus::NONE;
736*3e777be0SXin Li }
737*3e777be0SXin Li
738*3e777be0SXin Li template<typename HalVersion>
configureExecutionBurst(const sp<V1_2::IBurstCallback> & callback,const MQDescriptorSync<V1_2::FmqRequestDatum> & requestChannel,const MQDescriptorSync<V1_2::FmqResultDatum> & resultChannel,V1_2::IPreparedModel::configureExecutionBurst_cb cb)739*3e777be0SXin Li Return<void> ArmnnPreparedModel_1_2<HalVersion>::configureExecutionBurst(
740*3e777be0SXin Li const sp<V1_2::IBurstCallback>& callback,
741*3e777be0SXin Li const MQDescriptorSync<V1_2::FmqRequestDatum>& requestChannel,
742*3e777be0SXin Li const MQDescriptorSync<V1_2::FmqResultDatum>& resultChannel,
743*3e777be0SXin Li V1_2::IPreparedModel::configureExecutionBurst_cb cb)
744*3e777be0SXin Li {
745*3e777be0SXin Li ALOGV("ArmnnPreparedModel_1_2::configureExecutionBurst");
746*3e777be0SXin Li const sp<V1_2::IBurstContext> burst = ExecutionBurstServer::create(callback,
747*3e777be0SXin Li requestChannel,
748*3e777be0SXin Li resultChannel,
749*3e777be0SXin Li this);
750*3e777be0SXin Li
751*3e777be0SXin Li if (burst == nullptr)
752*3e777be0SXin Li {
753*3e777be0SXin Li cb(V1_0::ErrorStatus::GENERAL_FAILURE, {});
754*3e777be0SXin Li }
755*3e777be0SXin Li else
756*3e777be0SXin Li {
757*3e777be0SXin Li cb(V1_0::ErrorStatus::NONE, burst);
758*3e777be0SXin Li }
759*3e777be0SXin Li return Void();
760*3e777be0SXin Li }
761*3e777be0SXin Li
762*3e777be0SXin Li /// Schedule the graph prepared from the request for execution
763*3e777be0SXin Li template<typename HalVersion>
764*3e777be0SXin Li template<typename CallbackContext>
ScheduleGraphForExecution(std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>> & pMemPools,std::shared_ptr<armnn::InputTensors> & inputTensors,std::shared_ptr<armnn::OutputTensors> & outputTensors,CallbackContext callbackContext)765*3e777be0SXin Li void ArmnnPreparedModel_1_2<HalVersion>::ScheduleGraphForExecution(
766*3e777be0SXin Li std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
767*3e777be0SXin Li std::shared_ptr<armnn::InputTensors>& inputTensors,
768*3e777be0SXin Li std::shared_ptr<armnn::OutputTensors>& outputTensors,
769*3e777be0SXin Li CallbackContext callbackContext)
770*3e777be0SXin Li {
771*3e777be0SXin Li ALOGV("ArmnnPreparedModel_1_2::ScheduleGraphForExecution(...)");
772*3e777be0SXin Li
773*3e777be0SXin Li DumpTensorsIfRequired("Input", *inputTensors);
774*3e777be0SXin Li
775*3e777be0SXin Li unsigned int outputTensorSize = outputTensors.get()->size();
776*3e777be0SXin Li std::vector<V1_2::OutputShape> outputShapes(outputTensorSize);
777*3e777be0SXin Li for (unsigned int i = 0; i < outputTensorSize; i++)
778*3e777be0SXin Li {
779*3e777be0SXin Li std::pair<int, armnn::Tensor> outputTensorPair = outputTensors.get()->at(i);
780*3e777be0SXin Li const armnn::Tensor outputTensor = outputTensorPair.second;
781*3e777be0SXin Li const armnn::TensorInfo outputTensorInfo = outputTensor.GetInfo();
782*3e777be0SXin Li
783*3e777be0SXin Li outputShapes[i] = ComputeShape(outputTensorInfo);
784*3e777be0SXin Li }
785*3e777be0SXin Li
786*3e777be0SXin Li auto tpCb = std::make_shared<
787*3e777be0SXin Li ArmnnThreadPoolCallback_1_2<CallbackContext_1_2>>(this,
788*3e777be0SXin Li pMemPools,
789*3e777be0SXin Li outputShapes,
790*3e777be0SXin Li inputTensors,
791*3e777be0SXin Li outputTensors,
792*3e777be0SXin Li callbackContext);
793*3e777be0SXin Li
794*3e777be0SXin Li m_Threadpool->Schedule(m_NetworkId,
795*3e777be0SXin Li *tpCb->m_InputTensors,
796*3e777be0SXin Li *tpCb->m_OutputTensors,
797*3e777be0SXin Li armnn::QosExecPriority::Medium,
798*3e777be0SXin Li tpCb);
799*3e777be0SXin Li ALOGV("ArmnnPreparedModel_1_2::ScheduleGraphForExecution end");
800*3e777be0SXin Li }
801*3e777be0SXin Li
802*3e777be0SXin Li template<typename HalVersion>
803*3e777be0SXin Li template <typename CallbackContext>
Notify(armnn::Status status,armnn::InferenceTimingPair timeTaken)804*3e777be0SXin Li void ArmnnPreparedModel_1_2<HalVersion>::ArmnnThreadPoolCallback_1_2<CallbackContext>::Notify(
805*3e777be0SXin Li armnn::Status status, armnn::InferenceTimingPair timeTaken)
806*3e777be0SXin Li {
807*3e777be0SXin Li ALOGV("ArmnnPreparedModel_1_2::ArmnnThreadPoolCallback_1_2 Notify");
808*3e777be0SXin Li
809*3e777be0SXin Li TimePoint driverEnd;
810*3e777be0SXin Li
811*3e777be0SXin Li CommitPools(*m_MemPools);
812*3e777be0SXin Li
813*3e777be0SXin Li m_Model->DumpTensorsIfRequired("Output", *m_OutputTensors);
814*3e777be0SXin Li
815*3e777be0SXin Li if (status != armnn::Status::Success)
816*3e777be0SXin Li {
817*3e777be0SXin Li ALOGW("ArmnnThreadPoolCallback::Notify EnqueueWorkload failed");
818*3e777be0SXin Li m_CallbackContext.callback(
819*3e777be0SXin Li V1_0::ErrorStatus::GENERAL_FAILURE, {}, g_NoTiming, "ArmnnPreparedModel::ExecuteGraph");
820*3e777be0SXin Li return;
821*3e777be0SXin Li }
822*3e777be0SXin Li
823*3e777be0SXin Li if (m_CallbackContext.ctx.measureTimings == V1_2::MeasureTiming::YES)
824*3e777be0SXin Li {
825*3e777be0SXin Li driverEnd = std::chrono::steady_clock::now();
826*3e777be0SXin Li V1_2::Timing timing;
827*3e777be0SXin Li timing.timeOnDevice = MicrosecondsDuration(timeTaken.second, timeTaken.first);
828*3e777be0SXin Li timing.timeInDriver = MicrosecondsDuration(driverEnd, m_CallbackContext.ctx.driverStart);
829*3e777be0SXin Li ALOGV("ArmnnPreparedModel_1_2::execute timing - Device = %lu Driver = %lu",
830*3e777be0SXin Li static_cast<unsigned long>(timing.timeOnDevice), static_cast<unsigned long>(timing.timeInDriver));
831*3e777be0SXin Li m_CallbackContext.callback(
832*3e777be0SXin Li V1_0::ErrorStatus::NONE, m_OutputShapes, timing, "ArmnnPreparedModel_1_2::ExecuteGraph");
833*3e777be0SXin Li } else {
834*3e777be0SXin Li m_CallbackContext.callback(
835*3e777be0SXin Li V1_0::ErrorStatus::NONE, m_OutputShapes, g_NoTiming, "ArmnnPreparedModel_1_2::ExecuteGraph");
836*3e777be0SXin Li }
837*3e777be0SXin Li return;
838*3e777be0SXin Li }
839*3e777be0SXin Li
// Explicit instantiations for the HAL 1.2 policy. ExecuteGraph and
// ScheduleGraphForExecution are member templates, so the class instantiation
// alone does not emit them; each is instantiated for CallbackContext_1_2
// explicitly so their definitions in this translation unit are available to
// other TUs.
#if defined(ARMNN_ANDROID_NN_V1_2) || defined(ARMNN_ANDROID_NN_V1_3)
template class ArmnnPreparedModel_1_2<hal_1_2::HalPolicy>;
template bool ArmnnPreparedModel_1_2<hal_1_2::HalPolicy>::ExecuteGraph<CallbackContext_1_2>(
        std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
        armnn::InputTensors& pInputTensors,
        armnn::OutputTensors& pOutputTensors,
        CallbackContext_1_2 cb);

template void ArmnnPreparedModel_1_2<hal_1_2::HalPolicy>::ScheduleGraphForExecution<CallbackContext_1_2>(
        std::shared_ptr<std::vector<::android::nn::RunTimePoolInfo>>& pMemPools,
        std::shared_ptr<armnn::InputTensors>& inputTensors,
        std::shared_ptr<armnn::OutputTensors>& outputTensors,
        CallbackContext_1_2 callbackContext);
#endif
854*3e777be0SXin Li
855*3e777be0SXin Li } // namespace armnn_driver
856