xref: /aosp_15_r20/external/armnn/src/backends/reference/workloads/FullyConnected.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "FullyConnected.hpp"
7 
8 #include <armnn/utility/Assert.hpp>
9 
10 #include "RefWorkloadUtils.hpp"
11 
12 namespace armnn
13 {
14 
FullyConnected(const TensorShape & rInputShape,Decoder<float> & rInputDecoder,const TensorShape & rOutputShape,Encoder<float> & rOutputEncoder,const TensorShape & rWeightsShape,Decoder<float> & rWeightDecoder,Decoder<float> * pBiasDecoder,const bool biasEnabled,const unsigned int K,const bool transposeWeights)15 void FullyConnected(const TensorShape& rInputShape,
16                     Decoder<float>& rInputDecoder,
17                     const TensorShape& rOutputShape,
18                     Encoder<float>& rOutputEncoder,
19                     const TensorShape& rWeightsShape,
20                     Decoder<float>& rWeightDecoder,
21                     Decoder<float>* pBiasDecoder,
22                     const bool biasEnabled,
23                     const unsigned int K,
24                     const bool transposeWeights)
25 {
26     // Perform FullyConnected implementation
27     unsigned int outputSize = rOutputShape[1];
28 
29     const std::vector<float> decodedInputs = rInputDecoder.DecodeTensor(rInputShape);
30     const std::vector<float> decodedWeights = rWeightDecoder.DecodeTensor(rWeightsShape);
31 
32     const TensorShape biasShape{outputSize};
33 
34     ARMNN_ASSERT(!biasEnabled || pBiasDecoder != nullptr);
35     const std::vector<float> decodedBiases = biasEnabled ? pBiasDecoder->DecodeTensor(biasShape) : std::vector<float>();
36 
37 
38     for (unsigned int n = 0; n < rInputShape[0]; n++)
39     {
40         for (unsigned int channelOutput = 0; channelOutput < outputSize; channelOutput++)
41         {
42             float outval = 0.f;
43 
44             for (unsigned int channelInput = 0; channelInput < K; channelInput++)
45             {
46                 float weight;
47                 if (transposeWeights)
48                 {
49                     weight = decodedWeights[channelOutput * K + channelInput];
50                 }
51                 else
52                 {
53                     weight = decodedWeights[channelInput * outputSize + channelOutput];
54                 }
55 
56                 outval += weight * decodedInputs[n * K + channelInput];
57             }
58 
59             if (biasEnabled)
60             {
61                 outval += decodedBiases[channelOutput];
62             }
63 
64             rOutputEncoder[n * outputSize + channelOutput];
65             rOutputEncoder.Set(outval);
66         }
67     }
68 }
69 
70 } //namespace armnn
71