xref: /aosp_15_r20/external/armnn/src/backends/reference/workloads/RefReduceWorkload.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2020 Samsung Electronics Co Ltd and Contributors. All rights reserved.
3 // Copyright © 2021-2022 Arm Ltd and Contributors. All rights reserved.
4 // SPDX-License-Identifier: MIT
5 //
6 
7 #include "RefReduceWorkload.hpp"
8 
9 #include "Reduce.hpp"
10 #include "RefWorkloadUtils.hpp"
11 #include "BaseIterator.hpp"
12 #include "Profiling.hpp"
13 
14 namespace armnn
15 {
16 
RefReduceWorkload(const ReduceQueueDescriptor & descriptor,const WorkloadInfo & info)17 RefReduceWorkload::RefReduceWorkload(
18     const ReduceQueueDescriptor& descriptor,
19     const WorkloadInfo& info)
20     : RefBaseWorkload<ReduceQueueDescriptor>(descriptor, info) {}
21 
Execute() const22 void RefReduceWorkload::Execute() const
23 {
24     Execute(m_Data.m_Inputs, m_Data.m_Outputs);
25 }
26 
ExecuteAsync(ExecutionData & executionData)27 void RefReduceWorkload::ExecuteAsync(ExecutionData& executionData)
28 {
29     WorkingMemDescriptor* workingMemDescriptor = static_cast<WorkingMemDescriptor*>(executionData.m_Data);
30     Execute(workingMemDescriptor->m_Inputs, workingMemDescriptor->m_Outputs);
31 }
32 
Execute(std::vector<ITensorHandle * > inputs,std::vector<ITensorHandle * > outputs) const33 void RefReduceWorkload::Execute(std::vector<ITensorHandle*> inputs, std::vector<ITensorHandle*> outputs) const
34 {
35     ARMNN_SCOPED_PROFILING_EVENT(Compute::CpuRef, "RefReduceWorkload_Execute");
36 
37     const TensorInfo& inputInfo  = GetTensorInfo(inputs[0]);
38     const TensorInfo& outputInfo = GetTensorInfo(outputs[0]);
39 
40     std::unique_ptr<Decoder<float>> decoderPtr = MakeDecoder<float>(inputInfo, inputs[0]->Map());
41     Decoder<float>& decoder = *decoderPtr;
42 
43     std::unique_ptr<Encoder<float>> encoderPtr = MakeEncoder<float>(outputInfo, outputs[0]->Map());
44     Encoder<float>& encoder = *encoderPtr;
45 
46     Reduce(inputInfo,
47            outputInfo,
48            decoder,
49            encoder,
50            m_Data.m_Parameters.m_vAxis,
51            m_Data.m_Parameters.m_ReduceOperation);
52 }
53 
54 } //namespace armnn
55