1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker
6*89c4ff92SAndroid Build Coastguard Worker #include <ResolveType.hpp>
7*89c4ff92SAndroid Build Coastguard Worker
8*89c4ff92SAndroid Build Coastguard Worker #include "WorkloadUtils.hpp"
9*89c4ff92SAndroid Build Coastguard Worker
10*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/MemCopyWorkload.hpp>
11*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/TensorHandle.hpp>
12*89c4ff92SAndroid Build Coastguard Worker
13*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/PolymorphicDowncast.hpp>
14*89c4ff92SAndroid Build Coastguard Worker
15*89c4ff92SAndroid Build Coastguard Worker #include <cstring>
16*89c4ff92SAndroid Build Coastguard Worker
17*89c4ff92SAndroid Build Coastguard Worker namespace armnn
18*89c4ff92SAndroid Build Coastguard Worker {
19*89c4ff92SAndroid Build Coastguard Worker
20*89c4ff92SAndroid Build Coastguard Worker namespace
21*89c4ff92SAndroid Build Coastguard Worker {
22*89c4ff92SAndroid Build Coastguard Worker
23*89c4ff92SAndroid Build Coastguard Worker template <typename SrcTensorHandleType, typename DstTensorHandleType>
GatherTensorHandlePairs(const MemCopyQueueDescriptor & descriptor,std::vector<std::pair<SrcTensorHandleType *,DstTensorHandleType * >> & tensorHandlePairs)24*89c4ff92SAndroid Build Coastguard Worker void GatherTensorHandlePairs(const MemCopyQueueDescriptor& descriptor,
25*89c4ff92SAndroid Build Coastguard Worker std::vector<std::pair<SrcTensorHandleType*, DstTensorHandleType*>>& tensorHandlePairs)
26*89c4ff92SAndroid Build Coastguard Worker {
27*89c4ff92SAndroid Build Coastguard Worker const unsigned int numInputs = static_cast<unsigned int>(descriptor.m_Inputs.size());
28*89c4ff92SAndroid Build Coastguard Worker tensorHandlePairs.reserve(numInputs);
29*89c4ff92SAndroid Build Coastguard Worker
30*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < numInputs; ++i)
31*89c4ff92SAndroid Build Coastguard Worker {
32*89c4ff92SAndroid Build Coastguard Worker SrcTensorHandleType* const srcTensorHandle = PolymorphicDowncast<SrcTensorHandleType*>(
33*89c4ff92SAndroid Build Coastguard Worker descriptor.m_Inputs[i]);
34*89c4ff92SAndroid Build Coastguard Worker DstTensorHandleType* const dstTensorHandle = PolymorphicDowncast<DstTensorHandleType*>(
35*89c4ff92SAndroid Build Coastguard Worker descriptor.m_Outputs[i]);
36*89c4ff92SAndroid Build Coastguard Worker
37*89c4ff92SAndroid Build Coastguard Worker tensorHandlePairs.emplace_back(srcTensorHandle, dstTensorHandle);
38*89c4ff92SAndroid Build Coastguard Worker }
39*89c4ff92SAndroid Build Coastguard Worker }
40*89c4ff92SAndroid Build Coastguard Worker
41*89c4ff92SAndroid Build Coastguard Worker } //namespace
42*89c4ff92SAndroid Build Coastguard Worker
43*89c4ff92SAndroid Build Coastguard Worker
CopyMemGenericWorkload(const MemCopyQueueDescriptor & descriptor,const WorkloadInfo & info)44*89c4ff92SAndroid Build Coastguard Worker CopyMemGenericWorkload::CopyMemGenericWorkload(const MemCopyQueueDescriptor& descriptor,
45*89c4ff92SAndroid Build Coastguard Worker const WorkloadInfo& info)
46*89c4ff92SAndroid Build Coastguard Worker : BaseWorkload<MemCopyQueueDescriptor>(descriptor, info)
47*89c4ff92SAndroid Build Coastguard Worker {
48*89c4ff92SAndroid Build Coastguard Worker GatherTensorHandlePairs(descriptor, m_TensorHandlePairs);
49*89c4ff92SAndroid Build Coastguard Worker }
50*89c4ff92SAndroid Build Coastguard Worker
Execute() const51*89c4ff92SAndroid Build Coastguard Worker void CopyMemGenericWorkload::Execute() const
52*89c4ff92SAndroid Build Coastguard Worker {
53*89c4ff92SAndroid Build Coastguard Worker ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "CopyMemGeneric_Execute");
54*89c4ff92SAndroid Build Coastguard Worker
55*89c4ff92SAndroid Build Coastguard Worker auto copyFunc = [](void* dst, const void* src, size_t size)
56*89c4ff92SAndroid Build Coastguard Worker {
57*89c4ff92SAndroid Build Coastguard Worker memcpy(dst, src, size);
58*89c4ff92SAndroid Build Coastguard Worker };
59*89c4ff92SAndroid Build Coastguard Worker
60*89c4ff92SAndroid Build Coastguard Worker for (const auto& pair : m_TensorHandlePairs)
61*89c4ff92SAndroid Build Coastguard Worker {
62*89c4ff92SAndroid Build Coastguard Worker CopyTensorContentsGeneric(pair.first, pair.second, copyFunc);
63*89c4ff92SAndroid Build Coastguard Worker }
64*89c4ff92SAndroid Build Coastguard Worker }
65*89c4ff92SAndroid Build Coastguard Worker
ExecuteAsync(ExecutionData & executionData)66*89c4ff92SAndroid Build Coastguard Worker void CopyMemGenericWorkload::ExecuteAsync(ExecutionData& executionData)
67*89c4ff92SAndroid Build Coastguard Worker {
68*89c4ff92SAndroid Build Coastguard Worker ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "CopyMemGeneric_Execute_WorkingMemDescriptor");
69*89c4ff92SAndroid Build Coastguard Worker
70*89c4ff92SAndroid Build Coastguard Worker WorkingMemDescriptor* workingMemDescriptor = static_cast<WorkingMemDescriptor*>(executionData.m_Data);
71*89c4ff92SAndroid Build Coastguard Worker std::vector<TensorHandlePair> tensorHandlePairs;
72*89c4ff92SAndroid Build Coastguard Worker
73*89c4ff92SAndroid Build Coastguard Worker GatherTensorHandlePairs(*workingMemDescriptor, tensorHandlePairs);
74*89c4ff92SAndroid Build Coastguard Worker
75*89c4ff92SAndroid Build Coastguard Worker auto copyFunc = [](void* dst, const void* src, size_t size)
76*89c4ff92SAndroid Build Coastguard Worker {
77*89c4ff92SAndroid Build Coastguard Worker memcpy(dst, src, size);
78*89c4ff92SAndroid Build Coastguard Worker };
79*89c4ff92SAndroid Build Coastguard Worker
80*89c4ff92SAndroid Build Coastguard Worker for (const auto& pair : tensorHandlePairs)
81*89c4ff92SAndroid Build Coastguard Worker {
82*89c4ff92SAndroid Build Coastguard Worker CopyTensorContentsGeneric(pair.first, pair.second, copyFunc);
83*89c4ff92SAndroid Build Coastguard Worker }
84*89c4ff92SAndroid Build Coastguard Worker }
85*89c4ff92SAndroid Build Coastguard Worker
86*89c4ff92SAndroid Build Coastguard Worker } //namespace armnn
87