xref: /aosp_15_r20/external/armnn/src/backends/reference/workloads/TransposeConvolution2d.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "TransposeConvolution2d.hpp"
7 
8 #include <armnnUtils/DataLayoutIndexed.hpp>
9 
10 namespace armnn
11 {
12 
13 using namespace armnnUtils;
14 
TransposeConvolution2dImpl(const TransposeConvolution2dDescriptor & descriptor,const TensorShape & inputShape,Decoder<float> & inputDecoder,const TensorShape & outputShape,Encoder<float> & outputEncoder,const TensorShape & weightsShape,Decoder<float> & weightsDecoder,Decoder<float> * biasesDecoder)15 void TransposeConvolution2dImpl(const TransposeConvolution2dDescriptor& descriptor,
16                                 const TensorShape& inputShape,
17                                 Decoder<float>& inputDecoder,
18                                 const TensorShape& outputShape,
19                                 Encoder<float>& outputEncoder,
20                                 const TensorShape& weightsShape,
21                                 Decoder<float>& weightsDecoder,
22                                 Decoder<float>* biasesDecoder)
23 {
24     if (descriptor.m_BiasEnabled && !biasesDecoder)
25     {
26         throw InvalidArgumentException("Biases enabled but no bias data provided");
27     }
28     const DataLayoutIndexed dataLayoutIndexed(descriptor.m_DataLayout);
29     const unsigned int channelsIndex = dataLayoutIndexed.GetChannelsIndex();
30     const unsigned int heightIndex   = dataLayoutIndexed.GetHeightIndex();
31     const unsigned int widthIndex    = dataLayoutIndexed.GetWidthIndex();
32 
33     const unsigned int numBatches = inputShape[0];
34 
35     const unsigned int inputWidth  = inputShape[widthIndex];
36     const unsigned int inputHeight = inputShape[heightIndex];
37     const unsigned int inputDepth  = inputShape[channelsIndex];
38 
39     const unsigned int weightsHeight = weightsShape[heightIndex];
40     const unsigned int weightsWidth  = weightsShape[widthIndex];
41     const unsigned int weightsDepth  = weightsShape[channelsIndex];
42 
43     const unsigned int outputHeight = outputShape[heightIndex];
44     const unsigned int outputWidth  = outputShape[widthIndex];
45     const unsigned int outputDepth  = outputShape[channelsIndex];
46 
47     const unsigned int paddingLeft = descriptor.m_PadLeft;
48     const unsigned int paddingTop  = descriptor.m_PadTop;
49 
50     const unsigned int strideX = descriptor.m_StrideX;
51     const unsigned int strideY = descriptor.m_StrideY;
52 
53     std::vector<float> outputBuffer(outputShape.GetNumElements(), 0);
54 
55     const std::vector<float> inputVec = inputDecoder.DecodeTensor(inputShape);
56     const std::vector<float> filterVec = weightsDecoder.DecodeTensor(weightsShape);
57 
58     for (unsigned int batch = 0u; batch < numBatches; ++batch)
59     {
60         for (unsigned int yInput = 0u; yInput < inputHeight; ++yInput)
61         {
62             for (unsigned int xInput = 0u; xInput < inputWidth; ++xInput)
63             {
64                 unsigned int xOutputOrigin = xInput * strideX - paddingLeft;
65                 unsigned int yOutputOrigin = yInput * strideY - paddingTop;
66 
67                 for (unsigned int dOutput = 0u; dOutput < outputDepth; ++dOutput)
68                 {
69                     for (unsigned int yWeights = 0u; yWeights < weightsHeight; ++yWeights)
70                     {
71                         for (unsigned int xWeights = 0u; xWeights < weightsWidth; ++xWeights)
72                         {
73                             unsigned int yOutput = yOutputOrigin + yWeights;
74                             unsigned int xOutput = xOutputOrigin + xWeights;
75 
76                             if (yOutput < outputHeight && xOutput< outputWidth)
77                             {
78                                 for (unsigned int dInput = 0u; dInput < inputDepth; dInput++)
79                                 {
80                                     unsigned int inputIndex;
81                                     unsigned int outputIndex;
82                                     unsigned int weightsIndex;
83 
84                                     if(descriptor.m_DataLayout == armnn::DataLayout::NHWC)
85                                     {
86                                         inputIndex   = batch  * inputHeight * inputWidth * inputDepth +
87                                                        yInput * inputWidth * inputDepth +
88                                                        xInput * inputDepth +
89                                                        dInput;
90 
91                                         weightsIndex = dOutput  * weightsHeight * weightsWidth * weightsDepth +
92                                                        yWeights * weightsWidth * weightsDepth +
93                                                        xWeights * weightsDepth +
94                                                        dInput;
95 
96                                         outputIndex  = batch   * outputHeight * outputWidth * outputDepth +
97                                                        yOutput * outputWidth * outputDepth +
98                                                        xOutput * outputDepth +
99                                                        dOutput;
100                                     }
101                                     else
102                                     {
103                                         inputIndex   = batch  * inputDepth * inputHeight * inputWidth +
104                                                        dInput * inputHeight * inputWidth +
105                                                        yInput * inputWidth +
106                                                        xInput;
107 
108                                         weightsIndex = dOutput  * weightsDepth * weightsHeight * weightsWidth +
109                                                        dInput   * weightsHeight * weightsWidth +
110                                                        yWeights * weightsWidth +
111                                                        xWeights;
112 
113                                         outputIndex  = batch   * outputDepth * outputHeight * outputWidth +
114                                                        dOutput * outputHeight * outputWidth +
115                                                        yOutput * outputWidth +
116                                                        xOutput;
117                                     }
118 
119                                     outputBuffer[outputIndex] += inputVec[inputIndex] * filterVec[weightsIndex];
120                                 }
121                             }
122                         }
123                     }
124 
125                 }
126             }
127         }
128     }
129 
130     // Apply bias (if enabled)
131     if (descriptor.m_BiasEnabled)
132     {
133         outputEncoder[0];
134         Decoder<float>& rBiasesDecoder = *biasesDecoder;
135 
136         for (unsigned int batch = 0u; batch < numBatches; ++batch)
137         {
138             for (unsigned int dOutput = 0u; dOutput < outputDepth; ++dOutput)
139             {
140                 rBiasesDecoder[dOutput];
141                 for (unsigned int yOutput = 0u; yOutput < outputHeight; ++yOutput)
142                 {
143                     for (unsigned int xOutput = 0u; xOutput < outputWidth; ++xOutput)
144                     {
145                         const unsigned int outputIndex =
146                             dataLayoutIndexed.GetIndex(outputShape, batch, dOutput, yOutput, xOutput);
147                         outputBuffer[outputIndex] += rBiasesDecoder.Get();
148                     }
149                 }
150             }
151         }
152     }
153     outputEncoder[0];
154     for (float output : outputBuffer)
155     {
156         outputEncoder.Set(output);
157         ++outputEncoder;
158     }
159 }
160 
161 } // namespace armnn
162