xref: /aosp_15_r20/external/armnn/src/backends/reference/workloads/RefStridedSliceWorkload.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5 
6 #include "RefStridedSliceWorkload.hpp"
7 #include "RefWorkloadUtils.hpp"
8 #include "StridedSlice.hpp"
9 
10 namespace armnn
11 {
12 
RefStridedSliceWorkload(const StridedSliceQueueDescriptor & descriptor,const WorkloadInfo & info)13 RefStridedSliceWorkload::RefStridedSliceWorkload(const StridedSliceQueueDescriptor& descriptor,
14                                                  const WorkloadInfo& info)
15     : RefBaseWorkload(descriptor, info)
16 {}
17 
Execute() const18 void RefStridedSliceWorkload::Execute() const
19 {
20     Execute(m_Data.m_Inputs, m_Data.m_Outputs);
21 }
22 
ExecuteAsync(ExecutionData & executionData)23 void RefStridedSliceWorkload::ExecuteAsync(ExecutionData& executionData)
24 {
25     WorkingMemDescriptor* workingMemDescriptor = static_cast<WorkingMemDescriptor*>(executionData.m_Data);
26     Execute(workingMemDescriptor->m_Inputs, workingMemDescriptor->m_Outputs);
27 }
28 
Execute(std::vector<ITensorHandle * > inputs,std::vector<ITensorHandle * > outputs) const29 void RefStridedSliceWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const
30 {
31     ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefStridedSliceWorkload_Execute");
32 
33     const TensorInfo& inputInfo  = GetTensorInfo(inputs[0]);
34     const TensorInfo& outputInfo = GetTensorInfo(outputs[0]);
35 
36     DataType inputDataType  = inputInfo.GetDataType();
37     DataType outputDataType = outputInfo.GetDataType();
38 
39     ARMNN_ASSERT(inputDataType == outputDataType);
40     IgnoreUnused(outputDataType);
41 
42     StridedSlice(inputInfo,
43                  m_Data.m_Parameters,
44                  inputs[0]->Map(),
45                  outputs[0]->Map(),
46                  GetDataTypeSize(inputDataType));
47 }
48 
49 } // namespace armnn
50