1 //
2 // Copyright © 2019 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include "DepthToSpace.hpp"
7
8 #include <armnnUtils/DataLayoutIndexed.hpp>
9 #include <armnnUtils/Permute.hpp>
10
11 #include <armnn/utility/Assert.hpp>
12
13 using namespace armnnUtils;
14
15 namespace armnn
16 {
17
DepthToSpace(const TensorInfo & inputInfo,const DepthToSpaceDescriptor & descriptor,const void * inputData,void * outputData,unsigned int dataTypeSize)18 void DepthToSpace(const TensorInfo& inputInfo,
19 const DepthToSpaceDescriptor& descriptor,
20 const void* inputData,
21 void* outputData,
22 unsigned int dataTypeSize)
23 {
24 const unsigned int blockSize = descriptor.m_BlockSize;
25 ARMNN_ASSERT(blockSize != 0u);
26
27 const TensorShape& inputShape = inputInfo.GetShape();
28 const unsigned int batches = inputShape[0];
29
30 armnnUtils::DataLayoutIndexed dataLayoutIndexed(descriptor.m_DataLayout);
31 const unsigned int inDepth = inputShape[dataLayoutIndexed.GetChannelsIndex()];
32 const unsigned int inHeight = inputShape[dataLayoutIndexed.GetHeightIndex()];
33 const unsigned int inWidth = inputShape[dataLayoutIndexed.GetWidthIndex()];
34
35 const unsigned int outDepth = inDepth / (blockSize * blockSize);
36
37 // The 4D input data can be interpreted as 6D (implicitly reshaped) as follows:
38 //
39 // [batch, block size, block size, inDepth, inHeight, inWidth] for NCHW and
40 // [batch, inHeight, inWidth, blockSize, blockSize, outDepth] for NHWC.
41 //
42 // DepthToSpace can then be implemented as a permutation in 6D resulting in
43 // the following shapes:
44 //
45 // [batch, outDepth, inHeight, blockSize, inWidth, blockSize] for NCHW and
46 // [batch, inHeight, blockSize, inWidth, blockSize, outDepth] for NHWC.
47 //
48 // NOTE:
49 // Since 6D tensors are not currently supported, in practice we need to handle each
50 // batch separately and execute 5D permutations
51
52 TensorShape permDestShape;
53 PermutationVector permVector{};
54 if (descriptor.m_DataLayout == DataLayout::NCHW)
55 {
56 permDestShape = TensorShape({ outDepth, inHeight, blockSize, inWidth, blockSize });
57 permVector = { 2, 4, 0, 1, 3 };
58 }
59 else
60 {
61 permDestShape = TensorShape({ inHeight, blockSize, inWidth, blockSize, outDepth });
62 permVector = { 0, 2, 1, 3, 4 };
63 }
64
65 const unsigned int numElementsPerBatch = inputShape.GetNumElements() / batches;
66
67 for (unsigned int batchIndex = 0u; batchIndex < batches; ++batchIndex)
68 {
69 const uintptr_t batchDataOffset = batchIndex * (numElementsPerBatch * dataTypeSize);
70
71 armnnUtils::Permute(permDestShape,
72 permVector,
73 static_cast<const void*>(reinterpret_cast<const uint8_t*>(inputData) + batchDataOffset),
74 static_cast<void*>(reinterpret_cast<uint8_t*>(outputData) + batchDataOffset),
75 dataTypeSize);
76 }
77 }
78
79 } // namespace armnn
80