xref: /aosp_15_r20/external/armnn/src/backends/cl/workloads/ClBatchNormalizationFloatWorkload.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include "ClBaseWorkload.hpp"
9 
10 #include <arm_compute/runtime/CL/CLTensor.h>
11 #include <arm_compute/runtime/CL/functions/CLBatchNormalizationLayer.h>
12 
13 namespace armnn
14 {
15 
// Prototype of the validation function that accompanies the CL batch normalization
// workload declared below. Only the declaration is visible in this header.
// NOTE(review): judging by the name and the arm_compute::Status return type, this
// presumably reports whether the given tensor/descriptor combination is supported -
// confirm against the definition in the matching .cpp.
//
// Parameters:
//   input                - TensorInfo describing the input tensor.
//   output               - TensorInfo describing the output tensor.
//   mean, var            - TensorInfo for the mean and (presumably) variance parameters.
//   beta, gamma          - TensorInfo for the beta and gamma parameters.
//   descriptor           - the BatchNormalizationDescriptor of the layer.
//   activationDescriptor - optional pointer to an ActivationDescriptor; defaults to
//                          nullptr, so existing callers may omit it.
// Returns an arm_compute::Status by value.
16 arm_compute::Status ClBatchNormalizationValidate(const TensorInfo& input,
17                                                  const TensorInfo& output,
18                                                  const TensorInfo& mean,
19                                                  const TensorInfo& var,
20                                                  const TensorInfo& beta,
21                                                  const TensorInfo& gamma,
22                                                  const BatchNormalizationDescriptor& descriptor,
23                                                  const ActivationDescriptor* activationDescriptor = nullptr);
24 
// Batch normalization workload for the CL backend, derived from
// FloatWorkload<BatchNormalizationQueueDescriptor>.
// This header only declares the members; their definitions are not visible here,
// so the notes below describe the declarations and hedge anything about behaviour.
25 class ClBatchNormalizationFloatWorkload : public FloatWorkload<BatchNormalizationQueueDescriptor>
26 {
27 public:
       // Constructs the workload from a queue descriptor, the workload info and a
       // CL compile context. What gets configured here is defined in the .cpp.
28     ClBatchNormalizationFloatWorkload(const BatchNormalizationQueueDescriptor& descriptor,
29                                       const WorkloadInfo& info,
30                                       const arm_compute::CLCompileContext& clCompileContext);
31 
       // Inheriting-constructor declaration: the constructors of the FloatWorkload
       // base are made available on this class in addition to the one above.
32     using FloatWorkload<BatchNormalizationQueueDescriptor>::FloatWorkload;
       // Overrides the base-class Execute(). The method is const while m_Layer is
       // declared mutable - presumably so the layer can be run from here; confirm
       // in the definition.
33     void Execute() const override;
34 
35     // Replace input tensor handle with the given TensorHandle
       // NOTE(review): slot presumably selects which input is replaced - confirm.
36     void ReplaceInputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot) override;
37 
38     // Replace output tensor handle with the given TensorHandle
       // NOTE(review): slot presumably selects which output is replaced - confirm.
39     void ReplaceOutputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot) override;
40 
41 private:
       // The Arm Compute Library batch normalization function object. Declared
       // mutable, so it may be modified from const member functions.
42     mutable arm_compute::CLBatchNormalizationLayer m_Layer;
43 
       // CL tensors owned by this workload through unique_ptr.
       // NOTE(review): presumably the layer's mean/variance/gamma/beta parameters,
       // going by the member names - confirm where they are populated.
44     std::unique_ptr<arm_compute::CLTensor> m_Mean;
45     std::unique_ptr<arm_compute::CLTensor> m_Variance;
46     std::unique_ptr<arm_compute::CLTensor> m_Gamma;
47     std::unique_ptr<arm_compute::CLTensor> m_Beta;
48 
       // Private helper. NOTE(review): presumably releases the tensors above once
       // they are no longer needed - confirm against the definition.
49     void FreeUnusedTensors();
       // Private but virtual, so a derived class is able to override it.
       // NOTE(review): likely invoked after a tensor handle is replaced - confirm.
50     virtual void Reconfigure();
51 };
52 
53 } //namespace armnn
54 
55 
56 
57 
58