xref: /aosp_15_r20/external/armnn/src/armnn/optimizations/OptimizeConsecutiveReshapes.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include "Optimization.hpp"
8 
9 namespace armnn
10 {
11 namespace optimizations
12 {
13 
14 class OptimizeConsecutiveReshapesImpl
15 {
16 public:
17     /// Run for every connection between a base ReshapeLayer and a child ReshapeLayer.
18     /// Inserts an equivalent ReshapeLayer that bypasses both for that connection.
Run(Graph & graph,InputSlot & connection) const19     void Run(Graph& graph, InputSlot& connection) const
20     {
21         Layer& base = connection.GetConnectedOutputSlot()->GetOwningLayer();
22         Layer& child = connection.GetOwningLayer();
23 
24         ARMNN_ASSERT(base.GetType() == LayerType::Reshape);
25         ARMNN_ASSERT(child.GetType() == LayerType::Reshape);
26 
27         OutputSlot* parentOut = base.GetInputSlot(0).GetConnectedOutputSlot();
28 
29         const TensorInfo& inInfo = parentOut->GetTensorInfo();
30         const TensorInfo& outInfo = child.GetOutputHandler().GetTensorInfo();
31 
32         // This Optimization is only appropriate when the base ReshapeLayer is connected to the child ReshapeLayer
33         // and no other Layer.
34         if (base.GetOutputSlot(0).GetNumConnections() > 1)
35         {
36             return;
37         }
38 
39         if (inInfo.GetShape() != outInfo.GetShape())
40         {
41             // Inserts equivalent reshape before base layer.
42             const std::string name = std::string("merged-") + base.GetName() + std::string("-with-") + child.GetName();
43             const ReshapeDescriptor descriptor{outInfo.GetShape()};
44             auto& newReshape = *graph.InsertNewLayer<ReshapeLayer>(base.GetInputSlot(0), descriptor, name.c_str());
45 
46             // Parent is now the new layer.
47             parentOut = &newReshape.GetOutputSlot();
48         }
49 
50         // Moves connections in child output to parent layer.
51         // Child layer will be removed as it's left unconnected.
52         // Base layer will be removed if left unconnected.
53         child.GetOutputSlot().MoveAllConnections(*parentOut);
54     }
55 
56 protected:
57     OptimizeConsecutiveReshapesImpl() = default;
58     ~OptimizeConsecutiveReshapesImpl() = default;
59 };
60 
61 using OptimizeConsecutiveReshapes = OptimizeForConnection<ReshapeLayer, ReshapeLayer, OptimizeConsecutiveReshapesImpl>;
62 
63 } // namespace optimizations
64 } // namespace armnn
65