xref: /aosp_15_r20/external/armnn/src/armnnUtils/Permute.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include <armnn/Tensor.hpp>
7 
8 #include <armnnUtils/Permute.hpp>
9 
10 #include "Half.hpp"
11 
12 #include <cassert>
13 #include <cstring>
14 
15 namespace
16 {
17 
18 class PermuteLoop
19 {
20 public:
21     using size_type = unsigned int;
22 
PermuteLoop(const armnn::TensorShape & dstShape,const armnn::PermutationVector & mappings)23     PermuteLoop(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings)
24         : m_DstShape(dstShape)
25     {
26         assert(dstShape.GetNumDimensions() == mappings.GetSize());
27 
28         const size_type numDims = dstShape.GetNumDimensions();
29 
30         size_type srcStride = 1U;
31         size_type dstStride = 1U;
32 
33         for (size_type i = numDims - 1U, k = 0U; k < numDims; ++k, --i)
34         {
35             m_SrcStrides[mappings[i]] = srcStride;
36             m_DstStrides[i] = dstStride;
37 
38             srcStride *= dstShape[mappings[i]];
39             dstStride *= dstShape[i];
40         }
41     }
42 
Unroll(const void * srcData,void * dstData,size_t dataTypeSize)43     void Unroll(const void* srcData, void* dstData, size_t dataTypeSize)
44     {
45         assert(srcData);
46         assert(dstData);
47         assert(dataTypeSize > 0);
48 
49         const unsigned char* srcDataPtr = reinterpret_cast<const unsigned char*>(srcData);
50         unsigned char* dstDataPtr       = reinterpret_cast<unsigned char*>(dstData);
51 
52         const unsigned char* const srcEndPtr = srcDataPtr + m_DstShape.GetNumElements() * dataTypeSize;
53         unsigned char* const       dstEndPtr = dstDataPtr + m_DstShape.GetNumElements() * dataTypeSize;
54 
55         Unroll(0, srcDataPtr, dstDataPtr, srcEndPtr, dstEndPtr, dataTypeSize);
56     }
57 
58 private:
Unroll(size_type dimension,const unsigned char * srcData,unsigned char * dstData,const unsigned char * srcEnd,unsigned char * dstEnd,size_t dataTypeSize)59     void Unroll(size_type dimension,
60                 const unsigned char* srcData, unsigned char* dstData,
61                 const unsigned char* srcEnd, unsigned char* dstEnd,
62                 size_t dataTypeSize)
63     {
64         assert(srcData);
65         assert(dstData);
66         assert(srcEnd);
67         assert(dstEnd);
68         assert(srcData < srcEnd);
69         assert(dstData < dstEnd);
70         assert(dataTypeSize > 0);
71 
72         if (dimension >= m_DstShape.GetNumDimensions())
73         {
74             ::memcpy(dstData, srcData, dataTypeSize);
75         }
76         else
77         {
78             for (size_type i = 0; i < m_DstShape[dimension]; i++)
79             {
80                 Unroll(dimension + 1, srcData, dstData, srcEnd, dstEnd, dataTypeSize);
81 
82                 srcData += m_SrcStrides[dimension] * dataTypeSize;
83                 dstData += m_DstStrides[dimension] * dataTypeSize;
84             }
85         }
86     }
87 
88     armnn::TensorShape m_DstShape;
89     std::array<size_type, armnn::MaxNumOfTensorDimensions> m_SrcStrides;
90     std::array<size_type, armnn::MaxNumOfTensorDimensions> m_DstStrides;
91 };
92 
93 } // namespace
94 
95 namespace armnnUtils
96 {
97 
Permuted(const armnn::TensorShape & srcShape,const armnn::PermutationVector & mappings)98 armnn::TensorShape Permuted(const armnn::TensorShape& srcShape,
99                             const armnn::PermutationVector& mappings)
100 {
101     assert(srcShape.GetNumDimensions() == mappings.GetSize());
102 
103     const unsigned int numDims = mappings.GetSize();
104     unsigned int outDims[armnn::MaxNumOfTensorDimensions];
105 
106     for (unsigned int i = 0U; i < numDims; ++i)
107     {
108         outDims[mappings[i]] = srcShape[i];
109     }
110 
111     armnn::TensorShape permutedShape(numDims, outDims);
112     return permutedShape;
113 }
114 
Permuted(const armnn::TensorInfo & info,const armnn::PermutationVector & mappings)115 armnn::TensorInfo Permuted(const armnn::TensorInfo& info,
116                            const armnn::PermutationVector& mappings)
117 {
118     armnn::TensorInfo outInfo(info);
119     outInfo.SetShape(Permuted(info.GetShape(), mappings));
120 
121     // If TensorInfo has Per-Axis Quantization then it also has a QuantizationDim which needs to
122     // be permuted according to the mapping
123     if (info.GetQuantizationDim().has_value())
124     {
125         outInfo.SetQuantizationDim(mappings[info.GetQuantizationDim().value()]);
126     }
127 
128     return outInfo;
129 }
130 
Permute(const armnn::TensorShape & dstShape,const armnn::PermutationVector & mappings,const void * src,void * dst,size_t dataTypeSize)131 void Permute(const armnn::TensorShape& dstShape, const armnn::PermutationVector& mappings,
132              const void* src, void* dst, size_t dataTypeSize)
133 {
134     PermuteLoop(dstShape, mappings).Unroll(src, dst, dataTypeSize);
135 }
136 
137 } // namespace armnnUtils
138