xref: /aosp_15_r20/external/armnn/src/armnn/optimizations/PermuteAndBatchToSpaceAsDepthToSpace.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2019 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker #pragma once
6*89c4ff92SAndroid Build Coastguard Worker 
7*89c4ff92SAndroid Build Coastguard Worker #include "Optimization.hpp"
8*89c4ff92SAndroid Build Coastguard Worker 
9*89c4ff92SAndroid Build Coastguard Worker namespace armnn
10*89c4ff92SAndroid Build Coastguard Worker {
11*89c4ff92SAndroid Build Coastguard Worker namespace optimizations
12*89c4ff92SAndroid Build Coastguard Worker {
13*89c4ff92SAndroid Build Coastguard Worker 
14*89c4ff92SAndroid Build Coastguard Worker /// Replaces Permute leading into BatchToSpace with a DepthToSpace
15*89c4ff92SAndroid Build Coastguard Worker /// in the case where the Permute swaps the batch and channels dimensions
16*89c4ff92SAndroid Build Coastguard Worker /// such that the replacement is valid.
17*89c4ff92SAndroid Build Coastguard Worker template <typename PermuteType>
18*89c4ff92SAndroid Build Coastguard Worker class PermuteAndBatchToSpaceAsDepthToSpaceImpl
19*89c4ff92SAndroid Build Coastguard Worker {
20*89c4ff92SAndroid Build Coastguard Worker public:
Run(Graph & graph,InputSlot & connection) const21*89c4ff92SAndroid Build Coastguard Worker     void Run(Graph& graph, InputSlot& connection) const
22*89c4ff92SAndroid Build Coastguard Worker     {
23*89c4ff92SAndroid Build Coastguard Worker         // Validate base layer (the Permute) is compatible
24*89c4ff92SAndroid Build Coastguard Worker         Layer& base = connection.GetConnectedOutputSlot()->GetOwningLayer();
25*89c4ff92SAndroid Build Coastguard Worker         ARMNN_ASSERT(base.GetType() == LayerType::Permute || base.GetType() == LayerType::Transpose);
26*89c4ff92SAndroid Build Coastguard Worker         const TensorInfo& inputInfo = base.GetInputSlot(0).GetConnection()->GetTensorInfo();
27*89c4ff92SAndroid Build Coastguard Worker         const TensorInfo& intermediateInfo = base.GetOutputSlot(0).GetTensorInfo();
28*89c4ff92SAndroid Build Coastguard Worker         if (intermediateInfo.GetNumDimensions() != 4)
29*89c4ff92SAndroid Build Coastguard Worker         {
30*89c4ff92SAndroid Build Coastguard Worker             // Must be 4D, otherwise the below checks do not make sense
31*89c4ff92SAndroid Build Coastguard Worker             return;
32*89c4ff92SAndroid Build Coastguard Worker         }
33*89c4ff92SAndroid Build Coastguard Worker         if (!static_cast<PermuteType&>(base).GetParameters().m_DimMappings.IsEqual(PermutationVector{ 3, 1, 2, 0 }))
34*89c4ff92SAndroid Build Coastguard Worker         {
35*89c4ff92SAndroid Build Coastguard Worker             // Must swap batch and channels dimensions, otherwise it is not the (original) channels dimension
36*89c4ff92SAndroid Build Coastguard Worker             // that is being decomposed.
37*89c4ff92SAndroid Build Coastguard Worker             return;
38*89c4ff92SAndroid Build Coastguard Worker         }
39*89c4ff92SAndroid Build Coastguard Worker 
40*89c4ff92SAndroid Build Coastguard Worker         // Validate child layer (the BatchToSpace) is compatible
41*89c4ff92SAndroid Build Coastguard Worker         Layer& child = connection.GetOwningLayer();
42*89c4ff92SAndroid Build Coastguard Worker         ARMNN_ASSERT(child.GetType() == LayerType::BatchToSpaceNd);
43*89c4ff92SAndroid Build Coastguard Worker         const TensorInfo& outputInfo = child.GetOutputSlot(0).GetTensorInfo();
44*89c4ff92SAndroid Build Coastguard Worker         const BatchToSpaceNdDescriptor& batchToSpaceDesc = static_cast<BatchToSpaceNdLayer&>(child).GetParameters();
45*89c4ff92SAndroid Build Coastguard Worker         if (batchToSpaceDesc.m_DataLayout != DataLayout::NHWC)
46*89c4ff92SAndroid Build Coastguard Worker         {
47*89c4ff92SAndroid Build Coastguard Worker             // The rest of this function assumes NHWC, although in future this restriction could be lifted.
48*89c4ff92SAndroid Build Coastguard Worker             return;
49*89c4ff92SAndroid Build Coastguard Worker         }
50*89c4ff92SAndroid Build Coastguard Worker         if (batchToSpaceDesc.m_Crops != std::vector<std::pair<unsigned int, unsigned int>>{ { 0, 0 }, { 0, 0 } })
51*89c4ff92SAndroid Build Coastguard Worker         {
52*89c4ff92SAndroid Build Coastguard Worker             // Cropping is not supported in DepthToSpace
53*89c4ff92SAndroid Build Coastguard Worker             return;
54*89c4ff92SAndroid Build Coastguard Worker         }
55*89c4ff92SAndroid Build Coastguard Worker         if (batchToSpaceDesc.m_BlockShape.size() != 2 ||
56*89c4ff92SAndroid Build Coastguard Worker         batchToSpaceDesc.m_BlockShape[0] != batchToSpaceDesc.m_BlockShape[1])
57*89c4ff92SAndroid Build Coastguard Worker         {
58*89c4ff92SAndroid Build Coastguard Worker             // Asymmetric or non-2D block sizes are not supported by DepthToSpace
59*89c4ff92SAndroid Build Coastguard Worker             return;
60*89c4ff92SAndroid Build Coastguard Worker         }
61*89c4ff92SAndroid Build Coastguard Worker         uint32_t blockSize = batchToSpaceDesc.m_BlockShape[0];
62*89c4ff92SAndroid Build Coastguard Worker         if (outputInfo.GetShape()[0] != 1 || outputInfo.GetShape()[3] != 1)
63*89c4ff92SAndroid Build Coastguard Worker         {
64*89c4ff92SAndroid Build Coastguard Worker             // The final output must have 1 batch and 1 channel because these dimensions will be swapped around
65*89c4ff92SAndroid Build Coastguard Worker             // once we make the substitution, and it needs to be equivalent.
66*89c4ff92SAndroid Build Coastguard Worker             return;
67*89c4ff92SAndroid Build Coastguard Worker         }
68*89c4ff92SAndroid Build Coastguard Worker 
69*89c4ff92SAndroid Build Coastguard Worker         // Validate the intermediate tensor quantization params.
70*89c4ff92SAndroid Build Coastguard Worker         // These must be identical to either the input or output quantization params, otherwise the intermediate tensor
71*89c4ff92SAndroid Build Coastguard Worker         // may not have sufficient range/precision to preserve the values.
72*89c4ff92SAndroid Build Coastguard Worker         // This would mean that once we perform the substitution this loss of precision will no longer occur,
73*89c4ff92SAndroid Build Coastguard Worker         // so we would have changed the meaning of the network.
74*89c4ff92SAndroid Build Coastguard Worker         bool isIntermediateQuantParamsSameAsInput =
75*89c4ff92SAndroid Build Coastguard Worker                 intermediateInfo.GetQuantizationScale() == inputInfo.GetQuantizationScale() &&
76*89c4ff92SAndroid Build Coastguard Worker                 intermediateInfo.GetQuantizationOffset() == inputInfo.GetQuantizationOffset();
77*89c4ff92SAndroid Build Coastguard Worker         bool isIntermediateQuantParamsSameAsOutput =
78*89c4ff92SAndroid Build Coastguard Worker                 intermediateInfo.GetQuantizationScale() == outputInfo.GetQuantizationScale() &&
79*89c4ff92SAndroid Build Coastguard Worker                 intermediateInfo.GetQuantizationOffset() == outputInfo.GetQuantizationOffset();
80*89c4ff92SAndroid Build Coastguard Worker         if (!isIntermediateQuantParamsSameAsInput && !isIntermediateQuantParamsSameAsOutput)
81*89c4ff92SAndroid Build Coastguard Worker         {
82*89c4ff92SAndroid Build Coastguard Worker             return;
83*89c4ff92SAndroid Build Coastguard Worker         }
84*89c4ff92SAndroid Build Coastguard Worker 
85*89c4ff92SAndroid Build Coastguard Worker         // Insert equivalent DepthToSpace layer
86*89c4ff92SAndroid Build Coastguard Worker         const std::string name = std::string("merged-") + base.GetName() + std::string("-with-") + child.GetName();
87*89c4ff92SAndroid Build Coastguard Worker 
88*89c4ff92SAndroid Build Coastguard Worker         // Inserts equivalent reshape before base layer.
89*89c4ff92SAndroid Build Coastguard Worker         const DepthToSpaceDescriptor depthToSpaceDesc(blockSize, DataLayout::NHWC);
90*89c4ff92SAndroid Build Coastguard Worker         auto& depthToSpace = *graph.InsertNewLayer<DepthToSpaceLayer>(base.GetInputSlot(0),
91*89c4ff92SAndroid Build Coastguard Worker                                                                       depthToSpaceDesc,
92*89c4ff92SAndroid Build Coastguard Worker                                                                       name.c_str());
93*89c4ff92SAndroid Build Coastguard Worker 
94*89c4ff92SAndroid Build Coastguard Worker         // Moves connections from child output to new layer.
95*89c4ff92SAndroid Build Coastguard Worker         // Child layer will be removed as it's left unconnected.
96*89c4ff92SAndroid Build Coastguard Worker         // Base layer will be removed if left unconnected.
97*89c4ff92SAndroid Build Coastguard Worker         child.GetOutputSlot().MoveAllConnections(depthToSpace.GetOutputSlot());
98*89c4ff92SAndroid Build Coastguard Worker     }
99*89c4ff92SAndroid Build Coastguard Worker };
100*89c4ff92SAndroid Build Coastguard Worker 
101*89c4ff92SAndroid Build Coastguard Worker using PermuteAndBatchToSpaceAsDepthToSpace = OptimizeForConnection<PermuteLayer, BatchToSpaceNdLayer,
102*89c4ff92SAndroid Build Coastguard Worker     PermuteAndBatchToSpaceAsDepthToSpaceImpl<PermuteLayer>>;
103*89c4ff92SAndroid Build Coastguard Worker using TransposeAndBatchToSpaceAsDepthToSpace = OptimizeForConnection<TransposeLayer, BatchToSpaceNdLayer,
104*89c4ff92SAndroid Build Coastguard Worker     PermuteAndBatchToSpaceAsDepthToSpaceImpl<TransposeLayer>>;
105*89c4ff92SAndroid Build Coastguard Worker }    // namespace optimizations
106*89c4ff92SAndroid Build Coastguard Worker }    // namespace armnn
107