xref: /aosp_15_r20/external/armnn/src/backends/reference/workloads/BatchNormImpl.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "BatchNormImpl.hpp"
7 #include "RefWorkloadUtils.hpp"
8 
9 #include <armnn/Tensor.hpp>
10 
11 #include <armnnUtils/DataLayoutIndexed.hpp>
12 
13 #include <cmath>
14 
15 namespace armnn
16 {
17 
BatchNormImpl(const BatchNormalizationQueueDescriptor & data,Decoder<float> & meanDecoder,Decoder<float> & varianceDecoder,Decoder<float> & betaDecoder,Decoder<float> & gammaDecoder,Decoder<float> & inputDecoder,Encoder<float> & outputEncoder)18 void BatchNormImpl(const BatchNormalizationQueueDescriptor& data,
19                    Decoder<float>& meanDecoder,
20                    Decoder<float>& varianceDecoder,
21                    Decoder<float>& betaDecoder,
22                    Decoder<float>& gammaDecoder,
23                    Decoder<float>& inputDecoder,
24                    Encoder<float>& outputEncoder)
25 {
26     const TensorInfo& inputInfo = GetTensorInfo(data.m_Inputs[0]);
27     const TensorShape inputShape = inputInfo.GetShape();
28 
29     armnnUtils::DataLayoutIndexed dataLayout(data.m_Parameters.m_DataLayout);
30 
31     unsigned int inputBatches  = inputShape[0];
32     unsigned int inputHeight   = inputShape[dataLayout.GetHeightIndex()];
33     unsigned int inputWidth    = inputShape[dataLayout.GetWidthIndex()];
34     unsigned int inputChannels = inputShape[dataLayout.GetChannelsIndex()];
35 
36     for (unsigned int c = 0; c < inputChannels; c++)
37     {
38         meanDecoder[c];
39         varianceDecoder[c];
40         betaDecoder[c];
41         gammaDecoder[c];
42         float mean  = meanDecoder.Get();
43         float var   = varianceDecoder.Get();
44         float beta  = betaDecoder.Get();
45         float gamma = gammaDecoder.Get();
46 
47         float mult = gamma / sqrtf(var + data.m_Parameters.m_Eps);
48         float add  = beta - mult * mean;
49 
50         for (unsigned int n = 0; n < inputBatches; n++)
51         {
52             for (unsigned int h = 0; h < inputHeight; h++)
53             {
54                 for (unsigned int w = 0; w < inputWidth; w++)
55                 {
56                     unsigned int index = dataLayout.GetIndex(inputShape, n, c, h, w);
57                     inputDecoder[index];
58                     outputEncoder[index];
59                     outputEncoder.Set(mult * inputDecoder.Get() + add);
60                 }
61             }
62         }
63     }
64 }
65 
66 } // namespace armnn
67