1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker
6*89c4ff92SAndroid Build Coastguard Worker #pragma once
7*89c4ff92SAndroid Build Coastguard Worker
8*89c4ff92SAndroid Build Coastguard Worker #include <optimizations/FoldPadIntoLayer2d.hpp>
9*89c4ff92SAndroid Build Coastguard Worker
10*89c4ff92SAndroid Build Coastguard Worker namespace armnn
11*89c4ff92SAndroid Build Coastguard Worker {
12*89c4ff92SAndroid Build Coastguard Worker
13*89c4ff92SAndroid Build Coastguard Worker namespace
14*89c4ff92SAndroid Build Coastguard Worker {
15*89c4ff92SAndroid Build Coastguard Worker
16*89c4ff92SAndroid Build Coastguard Worker //
17*89c4ff92SAndroid Build Coastguard Worker // this helper only works if all layers where the inputs connect to are not selected
18*89c4ff92SAndroid Build Coastguard Worker //
19*89c4ff92SAndroid Build Coastguard Worker
CreateIInputsFrom(const std::vector<armnn::IConnectableLayer * > & layers)20*89c4ff92SAndroid Build Coastguard Worker SubgraphView::IInputSlots CreateIInputsFrom(const std::vector<armnn::IConnectableLayer*>& layers)
21*89c4ff92SAndroid Build Coastguard Worker {
22*89c4ff92SAndroid Build Coastguard Worker SubgraphView::IInputSlots result;
23*89c4ff92SAndroid Build Coastguard Worker for (auto&& layer : layers)
24*89c4ff92SAndroid Build Coastguard Worker {
25*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0 ; i < layer->GetNumInputSlots(); ++i)
26*89c4ff92SAndroid Build Coastguard Worker {
27*89c4ff92SAndroid Build Coastguard Worker result.push_back(&(layer->GetInputSlot(i)));
28*89c4ff92SAndroid Build Coastguard Worker }
29*89c4ff92SAndroid Build Coastguard Worker }
30*89c4ff92SAndroid Build Coastguard Worker return result;
31*89c4ff92SAndroid Build Coastguard Worker }
32*89c4ff92SAndroid Build Coastguard Worker
33*89c4ff92SAndroid Build Coastguard Worker //
34*89c4ff92SAndroid Build Coastguard Worker // this helper only works if all layers where the outputs connect to are not selected
35*89c4ff92SAndroid Build Coastguard Worker //
36*89c4ff92SAndroid Build Coastguard Worker
CreateIOutputsFrom(const std::vector<armnn::IConnectableLayer * > & layers)37*89c4ff92SAndroid Build Coastguard Worker SubgraphView::IOutputSlots CreateIOutputsFrom(const std::vector<armnn::IConnectableLayer*>& layers)
38*89c4ff92SAndroid Build Coastguard Worker {
39*89c4ff92SAndroid Build Coastguard Worker SubgraphView::IOutputSlots result;
40*89c4ff92SAndroid Build Coastguard Worker for (auto &&layer: layers)
41*89c4ff92SAndroid Build Coastguard Worker {
42*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < layer->GetNumOutputSlots(); ++i)
43*89c4ff92SAndroid Build Coastguard Worker {
44*89c4ff92SAndroid Build Coastguard Worker result.push_back(&(layer->GetOutputSlot(i)));
45*89c4ff92SAndroid Build Coastguard Worker }
46*89c4ff92SAndroid Build Coastguard Worker }
47*89c4ff92SAndroid Build Coastguard Worker return result;
48*89c4ff92SAndroid Build Coastguard Worker }
49*89c4ff92SAndroid Build Coastguard Worker
50*89c4ff92SAndroid Build Coastguard Worker }
51*89c4ff92SAndroid Build Coastguard Worker
ReportUntouchedLayers(OptimizationViews & optimizationViews,std::map<LayerGuid,Layer * > untouched)52*89c4ff92SAndroid Build Coastguard Worker inline void ReportUntouchedLayers(OptimizationViews& optimizationViews, std::map<LayerGuid, Layer*> untouched)
53*89c4ff92SAndroid Build Coastguard Worker {
54*89c4ff92SAndroid Build Coastguard Worker std::vector<Layer*> untouchedVector;
55*89c4ff92SAndroid Build Coastguard Worker for (const auto& pair : untouched)
56*89c4ff92SAndroid Build Coastguard Worker {
57*89c4ff92SAndroid Build Coastguard Worker Layer* layer = pair.second;
58*89c4ff92SAndroid Build Coastguard Worker SubgraphView subgraphView({layer},
59*89c4ff92SAndroid Build Coastguard Worker CreateIInputsFrom({layer}),
60*89c4ff92SAndroid Build Coastguard Worker CreateIOutputsFrom({layer}));
61*89c4ff92SAndroid Build Coastguard Worker optimizationViews.AddUntouchedSubgraph(std::move(subgraphView));
62*89c4ff92SAndroid Build Coastguard Worker }
63*89c4ff92SAndroid Build Coastguard Worker }
64*89c4ff92SAndroid Build Coastguard Worker
65*89c4ff92SAndroid Build Coastguard Worker template<typename LayerType>
FoldPadLayer(OptimizationViews & optimizationViews,LayerType * baseLayer,LayerType * replacementLayer,PadLayer * padLayer)66*89c4ff92SAndroid Build Coastguard Worker LayerType* FoldPadLayer(OptimizationViews& optimizationViews,
67*89c4ff92SAndroid Build Coastguard Worker LayerType* baseLayer,
68*89c4ff92SAndroid Build Coastguard Worker LayerType* replacementLayer,
69*89c4ff92SAndroid Build Coastguard Worker PadLayer* padLayer)
70*89c4ff92SAndroid Build Coastguard Worker {
71*89c4ff92SAndroid Build Coastguard Worker SubgraphView substitutionSubgraph({padLayer, baseLayer},
72*89c4ff92SAndroid Build Coastguard Worker CreateIInputsFrom({padLayer}),
73*89c4ff92SAndroid Build Coastguard Worker CreateIOutputsFrom({baseLayer}));
74*89c4ff92SAndroid Build Coastguard Worker SubgraphView replacementSubgraph(replacementLayer);
75*89c4ff92SAndroid Build Coastguard Worker
76*89c4ff92SAndroid Build Coastguard Worker optimizationViews.AddSubstitution({substitutionSubgraph, replacementSubgraph});
77*89c4ff92SAndroid Build Coastguard Worker
78*89c4ff92SAndroid Build Coastguard Worker return replacementLayer;
79*89c4ff92SAndroid Build Coastguard Worker }
80*89c4ff92SAndroid Build Coastguard Worker
81*89c4ff92SAndroid Build Coastguard Worker template<typename LayerType>
FoldPadIntoAveragePool2d(OptimizationViews & optimizationViews,Pooling2dLayer * baseLayer,Pooling2dDescriptor & poolDescriptor,PadLayer * padLayer)82*89c4ff92SAndroid Build Coastguard Worker LayerType* FoldPadIntoAveragePool2d(OptimizationViews& optimizationViews,
83*89c4ff92SAndroid Build Coastguard Worker Pooling2dLayer* baseLayer,
84*89c4ff92SAndroid Build Coastguard Worker Pooling2dDescriptor& poolDescriptor,
85*89c4ff92SAndroid Build Coastguard Worker PadLayer* padLayer)
86*89c4ff92SAndroid Build Coastguard Worker {
87*89c4ff92SAndroid Build Coastguard Worker IConnectableLayer* replacement =
88*89c4ff92SAndroid Build Coastguard Worker optimizationViews.GetINetwork()->AddPooling2dLayer(poolDescriptor, "folded-pad-into-pool2d");
89*89c4ff92SAndroid Build Coastguard Worker LayerType* replacementLayer = PolymorphicDowncast<LayerType*>(replacement);
90*89c4ff92SAndroid Build Coastguard Worker
91*89c4ff92SAndroid Build Coastguard Worker FoldPadLayer(optimizationViews,
92*89c4ff92SAndroid Build Coastguard Worker baseLayer,
93*89c4ff92SAndroid Build Coastguard Worker replacementLayer,
94*89c4ff92SAndroid Build Coastguard Worker padLayer);
95*89c4ff92SAndroid Build Coastguard Worker
96*89c4ff92SAndroid Build Coastguard Worker return replacementLayer;
97*89c4ff92SAndroid Build Coastguard Worker }
98*89c4ff92SAndroid Build Coastguard Worker
99*89c4ff92SAndroid Build Coastguard Worker } // namespace armnn
100