xref: /aosp_15_r20/external/armnn/src/backends/reference/workloads/DepthToSpace.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2019 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "DepthToSpace.hpp"
7 
8 #include <armnnUtils/DataLayoutIndexed.hpp>
9 #include <armnnUtils/Permute.hpp>
10 
11 #include <armnn/utility/Assert.hpp>
12 
13 using namespace armnnUtils;
14 
15 namespace armnn
16 {
17 
DepthToSpace(const TensorInfo & inputInfo,const DepthToSpaceDescriptor & descriptor,const void * inputData,void * outputData,unsigned int dataTypeSize)18 void DepthToSpace(const TensorInfo& inputInfo,
19                   const DepthToSpaceDescriptor& descriptor,
20                   const void* inputData,
21                   void* outputData,
22                   unsigned int dataTypeSize)
23 {
24     const unsigned int blockSize = descriptor.m_BlockSize;
25     ARMNN_ASSERT(blockSize != 0u);
26 
27     const TensorShape& inputShape = inputInfo.GetShape();
28     const unsigned int batches = inputShape[0];
29 
30     armnnUtils::DataLayoutIndexed dataLayoutIndexed(descriptor.m_DataLayout);
31     const unsigned int inDepth  = inputShape[dataLayoutIndexed.GetChannelsIndex()];
32     const unsigned int inHeight = inputShape[dataLayoutIndexed.GetHeightIndex()];
33     const unsigned int inWidth  = inputShape[dataLayoutIndexed.GetWidthIndex()];
34 
35     const unsigned int outDepth = inDepth / (blockSize * blockSize);
36 
37     // The 4D input data can be interpreted as 6D (implicitly reshaped) as follows:
38     //
39     // [batch, block size, block size, inDepth, inHeight, inWidth] for NCHW and
40     // [batch, inHeight, inWidth, blockSize, blockSize, outDepth] for NHWC.
41     //
42     // DepthToSpace can then be implemented as a permutation in 6D resulting in
43     // the following shapes:
44     //
45     // [batch, outDepth, inHeight, blockSize, inWidth, blockSize] for NCHW and
46     // [batch, inHeight, blockSize, inWidth, blockSize, outDepth] for NHWC.
47     //
48     // NOTE:
49     // Since 6D tensors are not currently supported, in practice we need to handle each
50     // batch separately and execute 5D permutations
51 
52     TensorShape permDestShape;
53     PermutationVector permVector{};
54     if (descriptor.m_DataLayout == DataLayout::NCHW)
55     {
56         permDestShape = TensorShape({ outDepth, inHeight, blockSize, inWidth, blockSize });
57         permVector    = { 2, 4, 0, 1, 3 };
58     }
59     else
60     {
61         permDestShape = TensorShape({ inHeight, blockSize, inWidth, blockSize, outDepth });
62         permVector    = { 0, 2, 1, 3, 4 };
63     }
64 
65     const unsigned int numElementsPerBatch = inputShape.GetNumElements() / batches;
66 
67     for (unsigned int batchIndex = 0u; batchIndex < batches; ++batchIndex)
68     {
69         const uintptr_t batchDataOffset = batchIndex * (numElementsPerBatch * dataTypeSize);
70 
71         armnnUtils::Permute(permDestShape,
72                             permVector,
73                             static_cast<const void*>(reinterpret_cast<const uint8_t*>(inputData) + batchDataOffset),
74                             static_cast<void*>(reinterpret_cast<uint8_t*>(outputData) + batchDataOffset),
75                             dataTypeSize);
76     }
77 }
78 
79 } // namespace armnn
80