//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "RefStridedSliceWorkload.hpp"
#include "RefWorkloadUtils.hpp"
#include "StridedSlice.hpp"

namespace armnn
{

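// The constructor only forwards the queue descriptor and workload info to the
// reference base workload; the strided slice parameters live in m_Data.m_Parameters.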
RefStridedSliceWorkload::RefStridedSliceWorkload(const StridedSliceQueueDescriptor& descriptor,
                                                 const WorkloadInfo& info)
    : RefBaseWorkload(descriptor, info)
{}

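// Synchronous execution: uses the tensor handles captured in the queue descriptor
// when the workload was created.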
void RefStridedSliceWorkload::Execute() const
{
    Execute(m_Data.m_Inputs, m_Data.m_Outputs);
}

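// Asynchronous execution: the tensor handles are supplied per call through the
// ExecutionData's working memory descriptor instead of being taken from m_Data.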
void RefStridedSliceWorkload::ExecuteAsync(ExecutionData& executionData)
{
    WorkingMemDescriptor* workingMemDescriptor = static_cast<WorkingMemDescriptor*>(executionData.m_Data);
    Execute(workingMemDescriptor->m_Inputs, workingMemDescriptor->m_Outputs);
}

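// Shared implementation: checks that the input and output data types match and
// delegates the actual slicing to the reference StridedSlice function.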
void RefStridedSliceWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const
{
    ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefStridedSliceWorkload_Execute");

    const TensorInfo& inputInfo = GetTensorInfo(inputs[0]);
    const TensorInfo& outputInfo = GetTensorInfo(outputs[0]);

    DataType inputDataType = inputInfo.GetDataType();
    DataType outputDataType = outputInfo.GetDataType();

    ARMNN_ASSERT(inputDataType == outputDataType);
    IgnoreUnused(outputDataType);

    StridedSlice(inputInfo,
                 m_Data.m_Parameters,
                 inputs[0]->Map(),
                 outputs[0]->Map(),
                 GetDataTypeSize(inputDataType));
}

} // namespace armnn