xref: /aosp_15_r20/external/armnn/src/backends/reference/workloads/Stack.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "Stack.hpp"
7 #include "RefWorkloadUtils.hpp"
8 
9 namespace armnn
10 {
11 
Stack(const StackQueueDescriptor & data,std::vector<std::unique_ptr<Decoder<float>>> & inputs,Encoder<float> & output,const TensorInfo & inputInfo,const TensorInfo & outputInfo)12 void Stack(const StackQueueDescriptor& data,
13            std::vector<std::unique_ptr<Decoder<float>>>& inputs,
14            Encoder<float>& output,
15            const TensorInfo& inputInfo,
16            const TensorInfo& outputInfo)
17 {
18     unsigned int outputNumDims = outputInfo.GetNumDimensions();
19     unsigned int inputNumDims = inputInfo.GetNumDimensions();
20 
21     const armnn::TensorShape& outputDims = outputInfo.GetShape();
22     const armnn::TensorShape& inputDims = inputInfo.GetShape();
23 
24     unsigned int axis = data.m_Parameters.m_Axis;
25 
26     // Can perform a simple concatenation when axis == 0
27     if (!axis)
28     {
29         unsigned int numInputs = data.m_Parameters.m_NumInputs;
30         unsigned int inputLength = inputInfo.GetNumElements();
31 
32         for (unsigned int inputIdx=0; inputIdx<numInputs; ++inputIdx)
33         {
34             for (unsigned int elmt=0; elmt<inputLength; ++elmt)
35             {
36                 (*inputs[inputIdx])[elmt];
37                 output[(inputIdx * inputLength) + elmt];
38                 output.Set(inputs[inputIdx]->Get());
39             }
40         }
41         return;
42     }
43 
44     const unsigned int iNumTensors = static_cast<unsigned int>(data.m_Inputs.size());
45     const unsigned int iBatchSize  = inputDims[0];
46     const unsigned int iChannels   = (inputNumDims > 1) ? inputDims[1] : 1;
47     const unsigned int iHeight     = (inputNumDims > 2) ? inputDims[2] : 1;
48     const unsigned int iWidth      = (inputNumDims > 3) ? inputDims[3] : 1;
49 
50     const unsigned int oBatchSize  = outputDims[1];
51     const unsigned int oChannels   = (outputNumDims > 2) ? outputDims[2] : 1;
52     const unsigned int oHeight     = (outputNumDims > 3) ? outputDims[3] : 1;
53     const unsigned int oWidth      = (outputNumDims > 4) ? outputDims[4] : 1;
54 
55     // Array to store the input coordinates
56     // iCoordinates[0] = i, iCoordinates[1] = bi, iCoordinates[2] = ci
57     // iCoordinates[3] = hi, iCoordinates[4] = wi, iCoordinates[5] = 0
58     // iCoordinates[5] will be always zero and used for not incrementing
59     // the output when the input has less than 4 dimensions
60     std::array<unsigned int, 6> iCoordinates{ 0 };
61 
62     // Array of pointers used to map the output coordinates to the input ones, in accordance with the axis
63     // This array is initialized with &iCoordinates[5] since this will be always zero
64     std::array<unsigned int *, 5> oCoordinates = { &iCoordinates[5],
65                                                    &iCoordinates[5],
66                                                    &iCoordinates[5],
67                                                    &iCoordinates[5],
68                                                    &iCoordinates[5] };
69 
70     // Set the axis coordinate
71     oCoordinates[axis] = &iCoordinates[0];
72 
73     // Map the output coordinates, accounting for the axis
74     unsigned int dim_shift = 0;
75     for(unsigned int dim = 0; dim < inputNumDims; ++dim)
76     {
77         if(dim == axis)
78         {
79             dim_shift++;
80         }
81         oCoordinates[dim + dim_shift] = &iCoordinates[dim + 1];
82     }
83 
84     // Alias for the input coordinates
85     unsigned int &i  = iCoordinates[0];
86     unsigned int &bi = iCoordinates[1];
87     unsigned int &ci = iCoordinates[2];
88     unsigned int &hi = iCoordinates[3];
89     unsigned int &wi = iCoordinates[4];
90 
91     // Alias for the output coordinates
92     unsigned int &o  = *(oCoordinates[0]);
93     unsigned int &bo = *(oCoordinates[1]);
94     unsigned int &co = *(oCoordinates[2]);
95     unsigned int &ho = *(oCoordinates[3]);
96     unsigned int &wo = *(oCoordinates[4]);
97 
98     // Stack tensors
99     for(; i < iNumTensors; ++(i))
100     {
101         for(bi = 0; bi < iBatchSize; ++(bi))
102         {
103             for(ci = 0; ci < iChannels; ++(ci))
104             {
105                 for(hi = 0; hi < iHeight; ++(hi))
106                 {
107                     for(wi = 0; wi < iWidth; ++(wi))
108                     {
109                         output[o  * oWidth * oHeight * oChannels * oBatchSize +
110                                bo * oWidth * oHeight * oChannels +
111                                co * oWidth * oHeight +
112                                ho * oWidth +
113                                wo];
114 
115                         output.Set(inputs[i]->Get());
116 
117                         ++(*(inputs[i]));
118                     }
119                 }
120             }
121         }
122     }
123 }
124 
125 } // namespace armnn
126