xref: /aosp_15_r20/external/armnn/src/armnnUtils/Transpose.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2020 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
#include <armnn/Tensor.hpp>

#include <armnnUtils/Transpose.hpp>

#include "Half.hpp"

#include <array>
#include <cassert>
#include <cstring>
14*89c4ff92SAndroid Build Coastguard Worker 
15*89c4ff92SAndroid Build Coastguard Worker namespace
16*89c4ff92SAndroid Build Coastguard Worker {
17*89c4ff92SAndroid Build Coastguard Worker 
18*89c4ff92SAndroid Build Coastguard Worker class TransposeLoop
19*89c4ff92SAndroid Build Coastguard Worker {
20*89c4ff92SAndroid Build Coastguard Worker public:
21*89c4ff92SAndroid Build Coastguard Worker     using size_type = unsigned int;
22*89c4ff92SAndroid Build Coastguard Worker 
TransposeLoop(const armnn::TensorShape & srcShape,const armnn::PermutationVector & mappings)23*89c4ff92SAndroid Build Coastguard Worker     TransposeLoop(const armnn::TensorShape& srcShape, const armnn::PermutationVector& mappings)
24*89c4ff92SAndroid Build Coastguard Worker         : m_SrcShape(srcShape)
25*89c4ff92SAndroid Build Coastguard Worker     {
26*89c4ff92SAndroid Build Coastguard Worker         assert(srcShape.GetNumDimensions() == mappings.GetSize());
27*89c4ff92SAndroid Build Coastguard Worker 
28*89c4ff92SAndroid Build Coastguard Worker         const size_type numDims = srcShape.GetNumDimensions();
29*89c4ff92SAndroid Build Coastguard Worker 
30*89c4ff92SAndroid Build Coastguard Worker         size_type srcStride = 1U;
31*89c4ff92SAndroid Build Coastguard Worker         size_type dstStride = 1U;
32*89c4ff92SAndroid Build Coastguard Worker 
33*89c4ff92SAndroid Build Coastguard Worker         for (size_type i = numDims - 1U, k = 0U; k < numDims; ++k, --i)
34*89c4ff92SAndroid Build Coastguard Worker         {
35*89c4ff92SAndroid Build Coastguard Worker             m_SrcStrides[i] = srcStride;
36*89c4ff92SAndroid Build Coastguard Worker             m_DstStrides[mappings[i]] = dstStride;
37*89c4ff92SAndroid Build Coastguard Worker 
38*89c4ff92SAndroid Build Coastguard Worker             srcStride *= srcShape[i];
39*89c4ff92SAndroid Build Coastguard Worker             dstStride *= srcShape[mappings[i]];
40*89c4ff92SAndroid Build Coastguard Worker         }
41*89c4ff92SAndroid Build Coastguard Worker     }
42*89c4ff92SAndroid Build Coastguard Worker 
Unroll(const void * srcData,void * dstData,size_t dataTypeSize)43*89c4ff92SAndroid Build Coastguard Worker     void Unroll(const void* srcData, void* dstData, size_t dataTypeSize)
44*89c4ff92SAndroid Build Coastguard Worker     {
45*89c4ff92SAndroid Build Coastguard Worker         assert(srcData);
46*89c4ff92SAndroid Build Coastguard Worker         assert(dstData);
47*89c4ff92SAndroid Build Coastguard Worker         assert(dataTypeSize > 0);
48*89c4ff92SAndroid Build Coastguard Worker 
49*89c4ff92SAndroid Build Coastguard Worker         const unsigned char* srcDataPtr = reinterpret_cast<const unsigned char*>(srcData);
50*89c4ff92SAndroid Build Coastguard Worker         unsigned char* dstDataPtr       = reinterpret_cast<unsigned char*>(dstData);
51*89c4ff92SAndroid Build Coastguard Worker 
52*89c4ff92SAndroid Build Coastguard Worker         const unsigned char* const srcEndPtr = srcDataPtr + m_SrcShape.GetNumElements() * dataTypeSize;
53*89c4ff92SAndroid Build Coastguard Worker         unsigned char* const       dstEndPtr = dstDataPtr + m_SrcShape.GetNumElements() * dataTypeSize;
54*89c4ff92SAndroid Build Coastguard Worker 
55*89c4ff92SAndroid Build Coastguard Worker         Unroll(0, srcDataPtr, dstDataPtr, srcEndPtr, dstEndPtr, dataTypeSize);
56*89c4ff92SAndroid Build Coastguard Worker     }
57*89c4ff92SAndroid Build Coastguard Worker 
58*89c4ff92SAndroid Build Coastguard Worker private:
Unroll(size_type dimension,const unsigned char * srcData,unsigned char * dstData,const unsigned char * srcEnd,unsigned char * dstEnd,size_t dataTypeSize)59*89c4ff92SAndroid Build Coastguard Worker     void Unroll(size_type dimension,
60*89c4ff92SAndroid Build Coastguard Worker                 const unsigned char* srcData, unsigned char* dstData,
61*89c4ff92SAndroid Build Coastguard Worker                 const unsigned char* srcEnd, unsigned char* dstEnd,
62*89c4ff92SAndroid Build Coastguard Worker                 size_t dataTypeSize)
63*89c4ff92SAndroid Build Coastguard Worker     {
64*89c4ff92SAndroid Build Coastguard Worker         assert(srcData);
65*89c4ff92SAndroid Build Coastguard Worker         assert(dstData);
66*89c4ff92SAndroid Build Coastguard Worker         assert(srcEnd);
67*89c4ff92SAndroid Build Coastguard Worker         assert(dstEnd);
68*89c4ff92SAndroid Build Coastguard Worker         assert(srcData < srcEnd);
69*89c4ff92SAndroid Build Coastguard Worker         assert(dstData < dstEnd);
70*89c4ff92SAndroid Build Coastguard Worker         assert(dataTypeSize > 0);
71*89c4ff92SAndroid Build Coastguard Worker 
72*89c4ff92SAndroid Build Coastguard Worker         if (dimension >= m_SrcShape.GetNumDimensions())
73*89c4ff92SAndroid Build Coastguard Worker         {
74*89c4ff92SAndroid Build Coastguard Worker             ::memcpy(dstData, srcData, dataTypeSize);
75*89c4ff92SAndroid Build Coastguard Worker         }
76*89c4ff92SAndroid Build Coastguard Worker         else
77*89c4ff92SAndroid Build Coastguard Worker         {
78*89c4ff92SAndroid Build Coastguard Worker             for (size_type i = 0; i < m_SrcShape[dimension]; i++)
79*89c4ff92SAndroid Build Coastguard Worker             {
80*89c4ff92SAndroid Build Coastguard Worker                 Unroll(dimension + 1, srcData, dstData, srcEnd, dstEnd, dataTypeSize);
81*89c4ff92SAndroid Build Coastguard Worker 
82*89c4ff92SAndroid Build Coastguard Worker                 srcData += m_SrcStrides[dimension] * dataTypeSize;
83*89c4ff92SAndroid Build Coastguard Worker                 dstData += m_DstStrides[dimension] * dataTypeSize;
84*89c4ff92SAndroid Build Coastguard Worker             }
85*89c4ff92SAndroid Build Coastguard Worker         }
86*89c4ff92SAndroid Build Coastguard Worker     }
87*89c4ff92SAndroid Build Coastguard Worker 
88*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorShape m_SrcShape;
89*89c4ff92SAndroid Build Coastguard Worker     std::array<size_type, armnn::MaxNumOfTensorDimensions> m_SrcStrides;
90*89c4ff92SAndroid Build Coastguard Worker     std::array<size_type, armnn::MaxNumOfTensorDimensions> m_DstStrides;
91*89c4ff92SAndroid Build Coastguard Worker };
92*89c4ff92SAndroid Build Coastguard Worker 
93*89c4ff92SAndroid Build Coastguard Worker } // namespace
94*89c4ff92SAndroid Build Coastguard Worker 
95*89c4ff92SAndroid Build Coastguard Worker namespace armnnUtils
96*89c4ff92SAndroid Build Coastguard Worker {
97*89c4ff92SAndroid Build Coastguard Worker 
/// Computes the shape obtained by applying a permutation to srcShape.
///
/// Dimension i of the returned shape has the extent of dimension mappings[i] of srcShape.
/// The number of dimensions of srcShape must equal mappings.GetSize() (asserted).
///
/// @param srcShape Shape to permute.
/// @param mappings Permutation to apply.
/// @return The permuted shape, with mappings.GetSize() dimensions.
armnn::TensorShape TransposeTensorShape(const armnn::TensorShape& srcShape, const armnn::PermutationVector& mappings)
{
    assert(srcShape.GetNumDimensions() == mappings.GetSize());

    const unsigned int rank = mappings.GetSize();

    // Scratch storage for the permuted extents; only the first 'rank' entries are written and read.
    unsigned int permutedExtents[armnn::MaxNumOfTensorDimensions];
    for (unsigned int dim = 0U; dim != rank; ++dim)
    {
        const unsigned int sourceDim = mappings[dim];
        permutedExtents[dim] = srcShape[sourceDim];
    }

    return armnn::TensorShape(rank, permutedExtents);
}
112*89c4ff92SAndroid Build Coastguard Worker 
/// Returns a copy of info whose shape has been permuted by mappings.
///
/// Everything other than the shape is carried over from info by the copy constructor.
///
/// @param info     Tensor description to start from.
/// @param mappings Permutation applied to info's shape.
/// @return The copied description with the permuted shape set.
armnn::TensorInfo TransposeTensorShape(const armnn::TensorInfo& info, const armnn::PermutationVector& mappings)
{
    armnn::TensorInfo transposedInfo(info);

    const armnn::TensorShape transposedShape = TransposeTensorShape(info.GetShape(), mappings);
    transposedInfo.SetShape(transposedShape);

    return transposedInfo;
}
119*89c4ff92SAndroid Build Coastguard Worker 
Transpose(const armnn::TensorShape & srcShape,const armnn::PermutationVector & mappings,const void * src,void * dst,size_t dataTypeSize)120*89c4ff92SAndroid Build Coastguard Worker void Transpose(const armnn::TensorShape& srcShape, const armnn::PermutationVector& mappings,
121*89c4ff92SAndroid Build Coastguard Worker              const void* src, void* dst, size_t dataTypeSize)
122*89c4ff92SAndroid Build Coastguard Worker {
123*89c4ff92SAndroid Build Coastguard Worker     TransposeLoop(srcShape, mappings).Unroll(src, dst, dataTypeSize);
124*89c4ff92SAndroid Build Coastguard Worker }
125*89c4ff92SAndroid Build Coastguard Worker 
126*89c4ff92SAndroid Build Coastguard Worker } // namespace armnnUtils
127