1 // 2 // Copyright © 2017 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 #pragma once 6 7 #include "Optimization.hpp" 8 9 namespace armnn 10 { 11 namespace optimizations 12 { 13 14 class OptimizeConsecutiveReshapesImpl 15 { 16 public: 17 /// Run for every connection between a base ReshapeLayer and a child ReshapeLayer. 18 /// Inserts an equivalent ReshapeLayer that bypasses both for that connection. Run(Graph & graph,InputSlot & connection) const19 void Run(Graph& graph, InputSlot& connection) const 20 { 21 Layer& base = connection.GetConnectedOutputSlot()->GetOwningLayer(); 22 Layer& child = connection.GetOwningLayer(); 23 24 ARMNN_ASSERT(base.GetType() == LayerType::Reshape); 25 ARMNN_ASSERT(child.GetType() == LayerType::Reshape); 26 27 OutputSlot* parentOut = base.GetInputSlot(0).GetConnectedOutputSlot(); 28 29 const TensorInfo& inInfo = parentOut->GetTensorInfo(); 30 const TensorInfo& outInfo = child.GetOutputHandler().GetTensorInfo(); 31 32 // This Optimization is only appropriate when the base ReshapeLayer is connected to the child ReshapeLayer 33 // and no other Layer. 34 if (base.GetOutputSlot(0).GetNumConnections() > 1) 35 { 36 return; 37 } 38 39 if (inInfo.GetShape() != outInfo.GetShape()) 40 { 41 // Inserts equivalent reshape before base layer. 42 const std::string name = std::string("merged-") + base.GetName() + std::string("-with-") + child.GetName(); 43 const ReshapeDescriptor descriptor{outInfo.GetShape()}; 44 auto& newReshape = *graph.InsertNewLayer<ReshapeLayer>(base.GetInputSlot(0), descriptor, name.c_str()); 45 46 // Parent is now the new layer. 47 parentOut = &newReshape.GetOutputSlot(); 48 } 49 50 // Moves connections in child output to parent layer. 51 // Child layer will be removed as it's left unconnected. 52 // Base layer will be removed if left unconnected. 53 child.GetOutputSlot().MoveAllConnections(*parentOut); 54 } 55 56 protected: 57 OptimizeConsecutiveReshapesImpl() = default; 58 ~OptimizeConsecutiveReshapesImpl() = default; 59 }; 60 61 using OptimizeConsecutiveReshapes = OptimizeForConnection<ReshapeLayer, ReshapeLayer, OptimizeConsecutiveReshapesImpl>; 62 63 } // namespace optimizations 64 } // namespace armnn 65