xref: /aosp_15_r20/external/armnn/src/backends/cl/workloads/ClMeanWorkload.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "ClMeanWorkload.hpp"
7 
8 #include <cl/ClTensorHandle.hpp>
9 #include <aclCommon/ArmComputeTensorUtils.hpp>
10 
11 #include "ClWorkloadUtils.hpp"
12 
13 namespace armnn
14 {
15 using namespace armcomputetensorutils;
16 
ClMeanValidate(const TensorInfo & input,const TensorInfo & output,const MeanDescriptor & descriptor)17 arm_compute::Status ClMeanValidate(const TensorInfo& input,
18                                    const TensorInfo& output,
19                                    const MeanDescriptor& descriptor)
20 {
21     const arm_compute::TensorInfo aclInputInfo  = armcomputetensorutils::BuildArmComputeTensorInfo(input);
22     const arm_compute::TensorInfo aclOutputInfo = armcomputetensorutils::BuildArmComputeTensorInfo(output);
23 
24     arm_compute::Coordinates coords = BuildArmComputeReductionCoordinates(aclInputInfo.num_dimensions(),
25                                                                           input.GetNumDimensions(),
26                                                                           descriptor.m_Axis);
27 
28     return arm_compute::CLReduceMean::validate(&aclInputInfo, coords, descriptor.m_KeepDims, &aclOutputInfo);
29 }
30 
ClMeanWorkload(const MeanQueueDescriptor & descriptor,const WorkloadInfo & info,const arm_compute::CLCompileContext & clCompileContext)31 ClMeanWorkload::ClMeanWorkload(const MeanQueueDescriptor& descriptor,
32                                const WorkloadInfo& info,
33                                const arm_compute::CLCompileContext& clCompileContext)
34     : ClBaseWorkload<MeanQueueDescriptor>(descriptor, info)
35 {
36     // Report Profiling Details
37     ARMNN_REPORT_PROFILING_WORKLOAD_DESC("ClMeanWorkload_Construct",
38                                          descriptor.m_Parameters,
39                                          info,
40                                          this->GetGuid());
41     m_Data.ValidateInputsOutputs("ClMeanWorkload", 1, 1);
42 
43     arm_compute::ICLTensor& input  = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
44     arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
45 
46     arm_compute::Coordinates coords = BuildArmComputeReductionCoordinates(input.info()->num_dimensions(),
47                                                                           info.m_InputTensorInfos[0].GetNumDimensions(),
48                                                                           m_Data.m_Parameters.m_Axis);
49 
50     {
51         ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "ClMeanWorkload_configure");
52         m_Layer.configure(clCompileContext, &input, coords, m_Data.m_Parameters.m_KeepDims, &output);
53     }
54 }
55 
Execute() const56 void ClMeanWorkload::Execute() const
57 {
58     ARMNN_SCOPED_PROFILING_EVENT_CL_GUID("ClMeanWorkload_Execute", this->GetGuid());
59     m_Layer.run();
60 }
61 
62 } //namespace armnn
63