1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2019 Arm Ltd. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker #pragma once 6*89c4ff92SAndroid Build Coastguard Worker 7*89c4ff92SAndroid Build Coastguard Worker #include "Optimization.hpp" 8*89c4ff92SAndroid Build Coastguard Worker 9*89c4ff92SAndroid Build Coastguard Worker namespace armnn 10*89c4ff92SAndroid Build Coastguard Worker { 11*89c4ff92SAndroid Build Coastguard Worker namespace optimizations 12*89c4ff92SAndroid Build Coastguard Worker { 13*89c4ff92SAndroid Build Coastguard Worker 14*89c4ff92SAndroid Build Coastguard Worker /// Replaces Permute leading into BatchToSpace with a DepthToSpace 15*89c4ff92SAndroid Build Coastguard Worker /// in the case where the Permute swaps the batch and channels dimensions 16*89c4ff92SAndroid Build Coastguard Worker /// such that the replacement is valid. 17*89c4ff92SAndroid Build Coastguard Worker template <typename PermuteType> 18*89c4ff92SAndroid Build Coastguard Worker class PermuteAndBatchToSpaceAsDepthToSpaceImpl 19*89c4ff92SAndroid Build Coastguard Worker { 20*89c4ff92SAndroid Build Coastguard Worker public: Run(Graph & graph,InputSlot & connection) const21*89c4ff92SAndroid Build Coastguard Worker void Run(Graph& graph, InputSlot& connection) const 22*89c4ff92SAndroid Build Coastguard Worker { 23*89c4ff92SAndroid Build Coastguard Worker // Validate base layer (the Permute) is compatible 24*89c4ff92SAndroid Build Coastguard Worker Layer& base = connection.GetConnectedOutputSlot()->GetOwningLayer(); 25*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(base.GetType() == LayerType::Permute || base.GetType() == LayerType::Transpose); 26*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& inputInfo = base.GetInputSlot(0).GetConnection()->GetTensorInfo(); 27*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& intermediateInfo = base.GetOutputSlot(0).GetTensorInfo(); 28*89c4ff92SAndroid Build Coastguard Worker if (intermediateInfo.GetNumDimensions() != 4) 29*89c4ff92SAndroid Build Coastguard Worker { 30*89c4ff92SAndroid Build Coastguard Worker // Must be 4D, otherwise the below checks do not make sense 31*89c4ff92SAndroid Build Coastguard Worker return; 32*89c4ff92SAndroid Build Coastguard Worker } 33*89c4ff92SAndroid Build Coastguard Worker if (!static_cast<PermuteType&>(base).GetParameters().m_DimMappings.IsEqual(PermutationVector{ 3, 1, 2, 0 })) 34*89c4ff92SAndroid Build Coastguard Worker { 35*89c4ff92SAndroid Build Coastguard Worker // Must swap batch and channels dimensions, otherwise it is not the (original) channels dimension 36*89c4ff92SAndroid Build Coastguard Worker // that is being decomposed. 37*89c4ff92SAndroid Build Coastguard Worker return; 38*89c4ff92SAndroid Build Coastguard Worker } 39*89c4ff92SAndroid Build Coastguard Worker 40*89c4ff92SAndroid Build Coastguard Worker // Validate child layer (the BatchToSpace) is compatible 41*89c4ff92SAndroid Build Coastguard Worker Layer& child = connection.GetOwningLayer(); 42*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(child.GetType() == LayerType::BatchToSpaceNd); 43*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outputInfo = child.GetOutputSlot(0).GetTensorInfo(); 44*89c4ff92SAndroid Build Coastguard Worker const BatchToSpaceNdDescriptor& batchToSpaceDesc = static_cast<BatchToSpaceNdLayer&>(child).GetParameters(); 45*89c4ff92SAndroid Build Coastguard Worker if (batchToSpaceDesc.m_DataLayout != DataLayout::NHWC) 46*89c4ff92SAndroid Build Coastguard Worker { 47*89c4ff92SAndroid Build Coastguard Worker // The rest of this function assumes NHWC, although in future this restriction could be lifted. 48*89c4ff92SAndroid Build Coastguard Worker return; 49*89c4ff92SAndroid Build Coastguard Worker } 50*89c4ff92SAndroid Build Coastguard Worker if (batchToSpaceDesc.m_Crops != std::vector<std::pair<unsigned int, unsigned int>>{ { 0, 0 }, { 0, 0 } }) 51*89c4ff92SAndroid Build Coastguard Worker { 52*89c4ff92SAndroid Build Coastguard Worker // Cropping is not supported in DepthToSpace 53*89c4ff92SAndroid Build Coastguard Worker return; 54*89c4ff92SAndroid Build Coastguard Worker } 55*89c4ff92SAndroid Build Coastguard Worker if (batchToSpaceDesc.m_BlockShape.size() != 2 || 56*89c4ff92SAndroid Build Coastguard Worker batchToSpaceDesc.m_BlockShape[0] != batchToSpaceDesc.m_BlockShape[1]) 57*89c4ff92SAndroid Build Coastguard Worker { 58*89c4ff92SAndroid Build Coastguard Worker // Asymmetric or non-2D block sizes are not supported by DepthToSpace 59*89c4ff92SAndroid Build Coastguard Worker return; 60*89c4ff92SAndroid Build Coastguard Worker } 61*89c4ff92SAndroid Build Coastguard Worker uint32_t blockSize = batchToSpaceDesc.m_BlockShape[0]; 62*89c4ff92SAndroid Build Coastguard Worker if (outputInfo.GetShape()[0] != 1 || outputInfo.GetShape()[3] != 1) 63*89c4ff92SAndroid Build Coastguard Worker { 64*89c4ff92SAndroid Build Coastguard Worker // The final output must have 1 batch and 1 channel because these dimensions will be swapped around 65*89c4ff92SAndroid Build Coastguard Worker // once we make the substitution, and it needs to be equivalent. 66*89c4ff92SAndroid Build Coastguard Worker return; 67*89c4ff92SAndroid Build Coastguard Worker } 68*89c4ff92SAndroid Build Coastguard Worker 69*89c4ff92SAndroid Build Coastguard Worker // Validate the intermediate tensor quantization params. 70*89c4ff92SAndroid Build Coastguard Worker // These must be identical to either the input or output quantization params, otherwise the intermediate tensor 71*89c4ff92SAndroid Build Coastguard Worker // may not have sufficient range/precision to preserve the values. 72*89c4ff92SAndroid Build Coastguard Worker // This would mean that once we perform the substitution this loss of precision will no longer occur, 73*89c4ff92SAndroid Build Coastguard Worker // so we would have changed the meaning of the network. 74*89c4ff92SAndroid Build Coastguard Worker bool isIntermediateQuantParamsSameAsInput = 75*89c4ff92SAndroid Build Coastguard Worker intermediateInfo.GetQuantizationScale() == inputInfo.GetQuantizationScale() && 76*89c4ff92SAndroid Build Coastguard Worker intermediateInfo.GetQuantizationOffset() == inputInfo.GetQuantizationOffset(); 77*89c4ff92SAndroid Build Coastguard Worker bool isIntermediateQuantParamsSameAsOutput = 78*89c4ff92SAndroid Build Coastguard Worker intermediateInfo.GetQuantizationScale() == outputInfo.GetQuantizationScale() && 79*89c4ff92SAndroid Build Coastguard Worker intermediateInfo.GetQuantizationOffset() == outputInfo.GetQuantizationOffset(); 80*89c4ff92SAndroid Build Coastguard Worker if (!isIntermediateQuantParamsSameAsInput && !isIntermediateQuantParamsSameAsOutput) 81*89c4ff92SAndroid Build Coastguard Worker { 82*89c4ff92SAndroid Build Coastguard Worker return; 83*89c4ff92SAndroid Build Coastguard Worker } 84*89c4ff92SAndroid Build Coastguard Worker 85*89c4ff92SAndroid Build Coastguard Worker // Insert equivalent DepthToSpace layer 86*89c4ff92SAndroid Build Coastguard Worker const std::string name = std::string("merged-") + base.GetName() + std::string("-with-") + child.GetName(); 87*89c4ff92SAndroid Build Coastguard Worker 88*89c4ff92SAndroid Build Coastguard Worker // Inserts equivalent reshape before base layer. 89*89c4ff92SAndroid Build Coastguard Worker const DepthToSpaceDescriptor depthToSpaceDesc(blockSize, DataLayout::NHWC); 90*89c4ff92SAndroid Build Coastguard Worker auto& depthToSpace = *graph.InsertNewLayer<DepthToSpaceLayer>(base.GetInputSlot(0), 91*89c4ff92SAndroid Build Coastguard Worker depthToSpaceDesc, 92*89c4ff92SAndroid Build Coastguard Worker name.c_str()); 93*89c4ff92SAndroid Build Coastguard Worker 94*89c4ff92SAndroid Build Coastguard Worker // Moves connections from child output to new layer. 95*89c4ff92SAndroid Build Coastguard Worker // Child layer will be removed as it's left unconnected. 96*89c4ff92SAndroid Build Coastguard Worker // Base layer will be removed if left unconnected. 97*89c4ff92SAndroid Build Coastguard Worker child.GetOutputSlot().MoveAllConnections(depthToSpace.GetOutputSlot()); 98*89c4ff92SAndroid Build Coastguard Worker } 99*89c4ff92SAndroid Build Coastguard Worker }; 100*89c4ff92SAndroid Build Coastguard Worker 101*89c4ff92SAndroid Build Coastguard Worker using PermuteAndBatchToSpaceAsDepthToSpace = OptimizeForConnection<PermuteLayer, BatchToSpaceNdLayer, 102*89c4ff92SAndroid Build Coastguard Worker PermuteAndBatchToSpaceAsDepthToSpaceImpl<PermuteLayer>>; 103*89c4ff92SAndroid Build Coastguard Worker using TransposeAndBatchToSpaceAsDepthToSpace = OptimizeForConnection<TransposeLayer, BatchToSpaceNdLayer, 104*89c4ff92SAndroid Build Coastguard Worker PermuteAndBatchToSpaceAsDepthToSpaceImpl<TransposeLayer>>; 105*89c4ff92SAndroid Build Coastguard Worker } // namespace optimizations 106*89c4ff92SAndroid Build Coastguard Worker } // namespace armnn 107