xref: /aosp_15_r20/external/armnn/src/backends/cl/workloads/ClSqrtWorkload.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "ClSqrtWorkload.hpp"
7 
8 #include "ClWorkloadUtils.hpp"
9 
10 #include <aclCommon/ArmComputeTensorUtils.hpp>
11 #include <aclCommon/ArmComputeUtils.hpp>
12 #include <armnn/utility/PolymorphicDowncast.hpp>
13 
14 #include <cl/ClTensorHandle.hpp>
15 
16 namespace armnn
17 {
18 
ClSqrtWorkloadValidate(const TensorInfo & input,const TensorInfo & output)19 arm_compute::Status ClSqrtWorkloadValidate(const TensorInfo& input, const TensorInfo& output)
20 {
21     const arm_compute::TensorInfo aclInput  = armcomputetensorutils::BuildArmComputeTensorInfo(input);
22     const arm_compute::TensorInfo aclOutput = armcomputetensorutils::BuildArmComputeTensorInfo(output);
23 
24     ActivationDescriptor descriptor;
25     descriptor.m_Function = ActivationFunction::Sqrt;
26     const arm_compute::ActivationLayerInfo activationLayerInfo =
27             ConvertActivationDescriptorToAclActivationLayerInfo(descriptor);
28 
29     return arm_compute::CLActivationLayer::validate(&aclInput, &aclOutput, activationLayerInfo);
30 }
31 
ClSqrtWorkload(const ElementwiseUnaryQueueDescriptor & descriptor,const WorkloadInfo & info,const arm_compute::CLCompileContext & clCompileContext)32 ClSqrtWorkload::ClSqrtWorkload(const ElementwiseUnaryQueueDescriptor& descriptor,
33                                const WorkloadInfo& info,
34                                const arm_compute::CLCompileContext& clCompileContext)
35     : ClBaseWorkload<ElementwiseUnaryQueueDescriptor>(descriptor, info)
36 {
37     ARMNN_ASSERT(descriptor.m_Parameters.m_Operation == UnaryOperation::Sqrt);
38 
39     // Report Profiling Details
40     ARMNN_REPORT_PROFILING_WORKLOAD_DESC("ClSqrtWorkload_Construct",
41                                          descriptor.m_Parameters,
42                                          info,
43                                          this->GetGuid());
44 
45     m_Data.ValidateInputsOutputs("ClSqrtWorkload", 1, 1);
46 
47     ActivationDescriptor activationDescriptor;
48     activationDescriptor.m_Function = ActivationFunction::Sqrt;
49     const arm_compute::ActivationLayerInfo activationLayerInfo =
50             ConvertActivationDescriptorToAclActivationLayerInfo(activationDescriptor);
51 
52     arm_compute::ICLTensor& input  = PolymorphicDowncast<ClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
53     arm_compute::ICLTensor& output = PolymorphicDowncast<ClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
54 
55     {
56         ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "ClSqrtWorkload_configure");
57         m_SqrtLayer.configure(clCompileContext, &input, &output, activationLayerInfo);
58     }
59 }
60 
Execute() const61 void ClSqrtWorkload::Execute() const
62 {
63     ARMNN_SCOPED_PROFILING_EVENT_CL_GUID("ClSqrtWorkload_Execute", this->GetGuid());
64     RunClFunction(m_SqrtLayer, CHECK_LOCATION());
65 }
66 
67 } // namespace armnn
68