1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2020 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker
6*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Tensor.hpp>
7*89c4ff92SAndroid Build Coastguard Worker
8*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/Transpose.hpp>
9*89c4ff92SAndroid Build Coastguard Worker
10*89c4ff92SAndroid Build Coastguard Worker #include "Half.hpp"
11*89c4ff92SAndroid Build Coastguard Worker
12*89c4ff92SAndroid Build Coastguard Worker #include <cassert>
13*89c4ff92SAndroid Build Coastguard Worker #include <cstring>
14*89c4ff92SAndroid Build Coastguard Worker
15*89c4ff92SAndroid Build Coastguard Worker namespace
16*89c4ff92SAndroid Build Coastguard Worker {
17*89c4ff92SAndroid Build Coastguard Worker
18*89c4ff92SAndroid Build Coastguard Worker class TransposeLoop
19*89c4ff92SAndroid Build Coastguard Worker {
20*89c4ff92SAndroid Build Coastguard Worker public:
21*89c4ff92SAndroid Build Coastguard Worker using size_type = unsigned int;
22*89c4ff92SAndroid Build Coastguard Worker
TransposeLoop(const armnn::TensorShape & srcShape,const armnn::PermutationVector & mappings)23*89c4ff92SAndroid Build Coastguard Worker TransposeLoop(const armnn::TensorShape& srcShape, const armnn::PermutationVector& mappings)
24*89c4ff92SAndroid Build Coastguard Worker : m_SrcShape(srcShape)
25*89c4ff92SAndroid Build Coastguard Worker {
26*89c4ff92SAndroid Build Coastguard Worker assert(srcShape.GetNumDimensions() == mappings.GetSize());
27*89c4ff92SAndroid Build Coastguard Worker
28*89c4ff92SAndroid Build Coastguard Worker const size_type numDims = srcShape.GetNumDimensions();
29*89c4ff92SAndroid Build Coastguard Worker
30*89c4ff92SAndroid Build Coastguard Worker size_type srcStride = 1U;
31*89c4ff92SAndroid Build Coastguard Worker size_type dstStride = 1U;
32*89c4ff92SAndroid Build Coastguard Worker
33*89c4ff92SAndroid Build Coastguard Worker for (size_type i = numDims - 1U, k = 0U; k < numDims; ++k, --i)
34*89c4ff92SAndroid Build Coastguard Worker {
35*89c4ff92SAndroid Build Coastguard Worker m_SrcStrides[i] = srcStride;
36*89c4ff92SAndroid Build Coastguard Worker m_DstStrides[mappings[i]] = dstStride;
37*89c4ff92SAndroid Build Coastguard Worker
38*89c4ff92SAndroid Build Coastguard Worker srcStride *= srcShape[i];
39*89c4ff92SAndroid Build Coastguard Worker dstStride *= srcShape[mappings[i]];
40*89c4ff92SAndroid Build Coastguard Worker }
41*89c4ff92SAndroid Build Coastguard Worker }
42*89c4ff92SAndroid Build Coastguard Worker
Unroll(const void * srcData,void * dstData,size_t dataTypeSize)43*89c4ff92SAndroid Build Coastguard Worker void Unroll(const void* srcData, void* dstData, size_t dataTypeSize)
44*89c4ff92SAndroid Build Coastguard Worker {
45*89c4ff92SAndroid Build Coastguard Worker assert(srcData);
46*89c4ff92SAndroid Build Coastguard Worker assert(dstData);
47*89c4ff92SAndroid Build Coastguard Worker assert(dataTypeSize > 0);
48*89c4ff92SAndroid Build Coastguard Worker
49*89c4ff92SAndroid Build Coastguard Worker const unsigned char* srcDataPtr = reinterpret_cast<const unsigned char*>(srcData);
50*89c4ff92SAndroid Build Coastguard Worker unsigned char* dstDataPtr = reinterpret_cast<unsigned char*>(dstData);
51*89c4ff92SAndroid Build Coastguard Worker
52*89c4ff92SAndroid Build Coastguard Worker const unsigned char* const srcEndPtr = srcDataPtr + m_SrcShape.GetNumElements() * dataTypeSize;
53*89c4ff92SAndroid Build Coastguard Worker unsigned char* const dstEndPtr = dstDataPtr + m_SrcShape.GetNumElements() * dataTypeSize;
54*89c4ff92SAndroid Build Coastguard Worker
55*89c4ff92SAndroid Build Coastguard Worker Unroll(0, srcDataPtr, dstDataPtr, srcEndPtr, dstEndPtr, dataTypeSize);
56*89c4ff92SAndroid Build Coastguard Worker }
57*89c4ff92SAndroid Build Coastguard Worker
58*89c4ff92SAndroid Build Coastguard Worker private:
Unroll(size_type dimension,const unsigned char * srcData,unsigned char * dstData,const unsigned char * srcEnd,unsigned char * dstEnd,size_t dataTypeSize)59*89c4ff92SAndroid Build Coastguard Worker void Unroll(size_type dimension,
60*89c4ff92SAndroid Build Coastguard Worker const unsigned char* srcData, unsigned char* dstData,
61*89c4ff92SAndroid Build Coastguard Worker const unsigned char* srcEnd, unsigned char* dstEnd,
62*89c4ff92SAndroid Build Coastguard Worker size_t dataTypeSize)
63*89c4ff92SAndroid Build Coastguard Worker {
64*89c4ff92SAndroid Build Coastguard Worker assert(srcData);
65*89c4ff92SAndroid Build Coastguard Worker assert(dstData);
66*89c4ff92SAndroid Build Coastguard Worker assert(srcEnd);
67*89c4ff92SAndroid Build Coastguard Worker assert(dstEnd);
68*89c4ff92SAndroid Build Coastguard Worker assert(srcData < srcEnd);
69*89c4ff92SAndroid Build Coastguard Worker assert(dstData < dstEnd);
70*89c4ff92SAndroid Build Coastguard Worker assert(dataTypeSize > 0);
71*89c4ff92SAndroid Build Coastguard Worker
72*89c4ff92SAndroid Build Coastguard Worker if (dimension >= m_SrcShape.GetNumDimensions())
73*89c4ff92SAndroid Build Coastguard Worker {
74*89c4ff92SAndroid Build Coastguard Worker ::memcpy(dstData, srcData, dataTypeSize);
75*89c4ff92SAndroid Build Coastguard Worker }
76*89c4ff92SAndroid Build Coastguard Worker else
77*89c4ff92SAndroid Build Coastguard Worker {
78*89c4ff92SAndroid Build Coastguard Worker for (size_type i = 0; i < m_SrcShape[dimension]; i++)
79*89c4ff92SAndroid Build Coastguard Worker {
80*89c4ff92SAndroid Build Coastguard Worker Unroll(dimension + 1, srcData, dstData, srcEnd, dstEnd, dataTypeSize);
81*89c4ff92SAndroid Build Coastguard Worker
82*89c4ff92SAndroid Build Coastguard Worker srcData += m_SrcStrides[dimension] * dataTypeSize;
83*89c4ff92SAndroid Build Coastguard Worker dstData += m_DstStrides[dimension] * dataTypeSize;
84*89c4ff92SAndroid Build Coastguard Worker }
85*89c4ff92SAndroid Build Coastguard Worker }
86*89c4ff92SAndroid Build Coastguard Worker }
87*89c4ff92SAndroid Build Coastguard Worker
88*89c4ff92SAndroid Build Coastguard Worker armnn::TensorShape m_SrcShape;
89*89c4ff92SAndroid Build Coastguard Worker std::array<size_type, armnn::MaxNumOfTensorDimensions> m_SrcStrides;
90*89c4ff92SAndroid Build Coastguard Worker std::array<size_type, armnn::MaxNumOfTensorDimensions> m_DstStrides;
91*89c4ff92SAndroid Build Coastguard Worker };
92*89c4ff92SAndroid Build Coastguard Worker
93*89c4ff92SAndroid Build Coastguard Worker } // namespace
94*89c4ff92SAndroid Build Coastguard Worker
95*89c4ff92SAndroid Build Coastguard Worker namespace armnnUtils
96*89c4ff92SAndroid Build Coastguard Worker {
97*89c4ff92SAndroid Build Coastguard Worker
/// Builds the shape produced by transposing `srcShape` with `mappings`.
///
/// Output dimension k takes the size of source dimension mappings[k]; the
/// result has as many dimensions as `mappings` has entries.
armnn::TensorShape TransposeTensorShape(const armnn::TensorShape& srcShape, const armnn::PermutationVector& mappings)
{
    assert(srcShape.GetNumDimensions() == mappings.GetSize());

    const unsigned int rank = mappings.GetSize();
    unsigned int transposedDims[armnn::MaxNumOfTensorDimensions];

    unsigned int dim = 0U;
    while (dim < rank)
    {
        transposedDims[dim] = srcShape[mappings[dim]];
        ++dim;
    }

    return armnn::TensorShape(rank, transposedDims);
}
112*89c4ff92SAndroid Build Coastguard Worker
/// Returns a copy of `info` whose shape has been transposed with `mappings`.
/// All other TensorInfo properties are carried over unchanged from `info`.
armnn::TensorInfo TransposeTensorShape(const armnn::TensorInfo& info, const armnn::PermutationVector& mappings)
{
    const armnn::TensorShape transposedShape = TransposeTensorShape(info.GetShape(), mappings);

    armnn::TensorInfo transposedInfo = info;
    transposedInfo.SetShape(transposedShape);
    return transposedInfo;
}
119*89c4ff92SAndroid Build Coastguard Worker
Transpose(const armnn::TensorShape & srcShape,const armnn::PermutationVector & mappings,const void * src,void * dst,size_t dataTypeSize)120*89c4ff92SAndroid Build Coastguard Worker void Transpose(const armnn::TensorShape& srcShape, const armnn::PermutationVector& mappings,
121*89c4ff92SAndroid Build Coastguard Worker const void* src, void* dst, size_t dataTypeSize)
122*89c4ff92SAndroid Build Coastguard Worker {
123*89c4ff92SAndroid Build Coastguard Worker TransposeLoop(srcShape, mappings).Unroll(src, dst, dataTypeSize);
124*89c4ff92SAndroid Build Coastguard Worker }
125*89c4ff92SAndroid Build Coastguard Worker
126*89c4ff92SAndroid Build Coastguard Worker } // namespace armnnUtils
127