xref: /aosp_15_r20/external/armnn/src/backends/reference/workloads/Softmax.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "Softmax.hpp"
7 
8 #include <armnnUtils/TensorUtils.hpp>
9 
10 #include <cmath>
11 #include <vector>
12 
13 namespace armnn
14 {
15 
16 /// Computes the softmax function on some inputs, into outputs, with a shape given by tensorInfo.
Softmax(Decoder<float> & in,Encoder<float> & out,const TensorInfo & inputTensorInfo,float beta,int axis)17 void Softmax(Decoder<float>& in, Encoder<float>& out, const TensorInfo& inputTensorInfo, float beta, int axis)
18 {
19     ARMNN_ASSERT_MSG(axis < static_cast<int>(inputTensorInfo.GetNumDimensions()),
20                      "Required axis index greater than number of dimensions.");
21     ARMNN_ASSERT_MSG(axis >= -static_cast<int>(inputTensorInfo.GetNumDimensions()),
22                      "Required axis index lower than negative of the number of dimensions");
23 
24     unsigned int uAxis = axis < 0  ?
25                          inputTensorInfo.GetNumDimensions() - static_cast<unsigned int>(abs(axis))
26                          : static_cast<unsigned int>(axis);
27 
28     const TensorShape& inputShape = inputTensorInfo.GetShape();
29     const unsigned int outerSize  = armnnUtils::GetNumElementsBetween(inputShape, 0, uAxis);
30     const unsigned int axisSize   = inputShape[uAxis];
31     const unsigned int innerSize  = armnnUtils::GetNumElementsBetween(inputShape,
32                                                                       uAxis + 1,
33                                                                       inputShape.GetNumDimensions());
34 
35     for (unsigned int outer = 0; outer < outerSize; ++outer)
36     {
37         unsigned int inputBeginIdx  = outer * axisSize * innerSize;
38         unsigned int inputEndIdx    = inputBeginIdx + axisSize * innerSize;
39         unsigned int outputBeginIdx = outer * axisSize * innerSize;
40 
41         for (unsigned int inner = 0; inner < innerSize; ++inner, ++inputBeginIdx, ++inputEndIdx, ++outputBeginIdx)
42         {
43             // Find max
44             float maxValue = std::numeric_limits<float>::lowest();
45             for (unsigned int iter = inputBeginIdx; iter < inputEndIdx; iter += innerSize)
46             {
47                 in[iter];
48                 maxValue = std::max(maxValue, in.Get());
49             }
50 
51             // Compute sum
52             float sum = 0.0f;
53             for (unsigned int iter = inputBeginIdx; iter < inputEndIdx; iter += innerSize)
54             {
55                 in[iter];
56                 sum += std::exp((in.Get() - maxValue) * beta);
57             }
58 
59             // Compute result
60             unsigned int outputIter = outputBeginIdx;
61             out[outputIter];
62             for (unsigned int iter = inputBeginIdx; iter < inputEndIdx; iter += innerSize, outputIter += innerSize)
63             {
64                 out[outputIter];
65                 in[iter];
66                 out.Set(std::exp((in.Get() - maxValue) * beta) / sum);
67             }
68         }
69     }
70 }
71 
72 } //namespace armnn
73