xref: /aosp_15_r20/external/armnn/src/armnn/SubgraphViewSelector.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017, 2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "SubgraphViewSelector.hpp"
7 #include "Graph.hpp"
8 
9 #include <armnn/utility/Assert.hpp>
10 #include <armnn/utility/IgnoreUnused.hpp>
11 #include <armnn/utility/PolymorphicDowncast.hpp>
12 
13 #include <algorithm>
14 #include <map>
15 #include <queue>
16 #include <unordered_set>
17 
18 namespace armnn
19 {
20 
21 namespace
22 {
23 
24 /// Intermediate data-structure to store the subgraph that a layer has been assigned to.
25 /// This is a "disjoint set" data structure that allows efficient merging of subgraphs,
26 /// which is a key part of the algorithm. Subgraphs are arranged in singly-linked trees
27 /// (with each node storing a pointer to its parent). Subgraphs in the same tree are considered
28 /// to have been merged. Merging subgraphs is performed by attaching one tree to another,
29 /// which is a simple pointer update.
30 ///
31 /// NOTE: Due to the way this is stored, it is almost never correct to directly compare pointers
32 /// to two PartialSubgraphs to check if two layers belong in the same subgraph. Instead you
33 /// should use IsMergedWith().
34 ///
35 /// This structure also stores information about the dependencies of each subgraph, which is needed
36 /// to determine whether certain subgraphs can be merged. Checking whether a subgraph
37 /// depends on another subgraph is a frequent operation in the algorithm (see AssignSplitId) and so this is optimized
38 /// in preference to the merging of subgraphs. This leads to an approach where each subgraph stores
39 /// a set of all the subgraphs it depends on (for a fast lookup). In order to efficiently update this
40 /// set as subgraphs are merged means we also store a set of subgraphs which *depend on us* (i.e. the
41 /// complement of our dependencies).
42 class PartialSubgraph
43 {
44 public:
45     /// If this subgraph has been merged with another then there is an agreed "representative" for the combined
46     /// subgraph, which uniquely identifies the subgraph.
GetRepresentative()47     PartialSubgraph* GetRepresentative()
48     {
49         // Recurse up the tree to find the root node.
50         if (m_Parent == nullptr)
51         {
52             return this;
53         }
54         else
55         {
56             PartialSubgraph* result = m_Parent->GetRepresentative();
57             // Update our parent pointer to point directly to the root in order to speed up future calls to this method.
58             // This essentially "flattens" the tree.
59             m_Parent = result;
60             return result;
61         }
62     }
63 
64     /// Merges this subgraph with another.
MergeWith(PartialSubgraph * other)65     void MergeWith(PartialSubgraph* other)
66     {
67         if (m_Parent == nullptr)
68         {
69             other = other->GetRepresentative();
70             if (this == other)
71             {
72                 // Already merged - no-op
73                 return;
74             }
75             m_Parent = other;
76 
77             // Update others' dependency sets to point to the new representative rather than us.
78             // Keeping these up-to-date means we can rely on these sets containing representatives when
79             // we perform a lookup in HasAntecedent() and so don't need to resolve the representative for each element
80             // of the set. See description at the top of this class for more rationale.
81             for (PartialSubgraph* a : m_Antecedents)
82             {
83                 size_t numErased = a->m_Dependants.erase(this);
84                 ARMNN_ASSERT(numErased == 1);
85                 IgnoreUnused(numErased);
86                 a->m_Dependants.insert(m_Parent);
87             }
88             for (PartialSubgraph* a : m_Dependants)
89             {
90                 size_t numErased = a->m_Antecedents.erase(this);
91                 ARMNN_ASSERT(numErased == 1);
92                 IgnoreUnused(numErased);
93                 a->m_Antecedents.insert(m_Parent);
94             }
95 
96             // Merge our dependency sets into our new representative.
97             // We no longer need to maintain our own sets, as requests will always be forwarded to the representative.
98             m_Parent->m_Antecedents.insert(m_Antecedents.begin(), m_Antecedents.end());
99             m_Antecedents.clear();
100             m_Parent->m_Dependants.insert(m_Dependants.begin(), m_Dependants.end());
101             m_Dependants.clear();
102         }
103         else
104         {
105             // Defer request to the representative
106             GetRepresentative()->MergeWith(other);
107         }
108     }
109 
110     /// Checks if this subgraph has been merged with the given subgraph.
IsMergedWith(PartialSubgraph * other)111     bool IsMergedWith(PartialSubgraph* other)
112     {
113         return GetRepresentative() == other->GetRepresentative();
114     }
115 
116     /// Marks the given subgraph as a direct antecedent (dependency) of this one.
AddDirectAntecedent(PartialSubgraph * antecedent)117     void AddDirectAntecedent(PartialSubgraph* antecedent)
118     {
119         if (m_Parent == nullptr)
120         {
121             antecedent = antecedent->GetRepresentative();
122 
123             m_Antecedents.insert(antecedent);
124             // Also record all of its antecedents, so that we end up with direct and indirect antecedents.
125             // This makes the lookup in HasAntecedent() faster.
126             m_Antecedents.insert(antecedent->m_Antecedents.begin(), antecedent->m_Antecedents.end());
127             // All of our dependents also need to include the new antecedents
128             for (PartialSubgraph* d : m_Dependants)
129             {
130                 d->m_Antecedents.insert(antecedent);
131                 d->m_Antecedents.insert(antecedent->m_Antecedents.begin(), antecedent->m_Antecedents.end());
132             }
133 
134             // Store reverse dependencies as well, required so that we can efficiently navigate the graph
135             // when making updates.
136             antecedent->m_Dependants.insert(this);
137             antecedent->m_Dependants.insert(m_Dependants.begin(), m_Dependants.end());
138             for (PartialSubgraph* a : antecedent->m_Antecedents)
139             {
140                 a->m_Dependants.insert(this);
141                 a->m_Dependants.insert(m_Dependants.begin(), m_Dependants.end());
142             }
143         }
144         else
145         {
146             // Defer request to the representative
147             GetRepresentative()->AddDirectAntecedent(antecedent);
148         }
149     }
150 
151     /// Checks if this subgraph is dependent on the given subgraph, either directly or indirectly.
HasAntecedent(PartialSubgraph * antecedent)152     bool HasAntecedent(PartialSubgraph* antecedent)
153     {
154         if (m_Parent == nullptr)
155         {
156             antecedent = antecedent->GetRepresentative();
157             // Thanks to keeping this set updated in MergeWith and AddDirectAntecedent, we can do an efficient lookup.
158             return m_Antecedents.count(antecedent) > 0;
159         }
160         else
161         {
162             // Defer request to the representative
163             return GetRepresentative()->HasAntecedent(antecedent);
164         }
165     }
166 
167 private:
168     /// Pointer to the parent node in the tree. If this is null then we are the representative for our merged subgraph.
169     PartialSubgraph* m_Parent;
170     /// The representatives of all the subgraphs which we depend on, either directly or indirectly.
171     std::unordered_set<PartialSubgraph*> m_Antecedents;
172     /// The representatives of all the subgraphs which depend on us, either directly or indirectly.
173     std::unordered_set<PartialSubgraph*> m_Dependants;
174 };
175 
176 /// Intermediate data structure to store information associated with a particular layer.
177 struct LayerSelectionInfo
178 {
179     using LayerInfoContainer = std::map<IConnectableLayer*, LayerSelectionInfo>;
180     using LayerInfoQueue = std::queue<LayerSelectionInfo*>;
181 
LayerSelectionInfoarmnn::__anon04db257d0111::LayerSelectionInfo182     LayerSelectionInfo(Layer* layer, const SubgraphViewSelector::LayerSelectorFunction& selector)
183     : m_Layer{layer}
184     , m_Subgraph{nullptr}
185     , m_IsSelected{selector(*layer)}
186     , m_IsProcessed(false)
187     {
188     }
189 
IsInputLayerarmnn::__anon04db257d0111::LayerSelectionInfo190     bool IsInputLayer() const
191     {
192         return m_Layer->GetType() == armnn::LayerType::Input || m_Layer->GetType() == armnn::LayerType::Constant;
193     }
194 
CollectNonSelectedInputsarmnn::__anon04db257d0111::LayerSelectionInfo195     void CollectNonSelectedInputs(LayerSelectionInfo::LayerInfoContainer& layerInfos,
196                                   SubgraphView::IInputSlots& inputSlots)
197     {
198         for (auto&& slot = PolymorphicDowncast<Layer*>(m_Layer)->BeginInputSlots();
199              slot != PolymorphicDowncast<Layer*>(m_Layer)->EndInputSlots();
200              ++slot)
201         {
202             OutputSlot* parentLayerOutputSlot = slot->GetConnectedOutputSlot();
203             ARMNN_ASSERT_MSG(parentLayerOutputSlot != nullptr, "The input slots must be connected here.");
204             if (parentLayerOutputSlot)
205             {
206                 Layer& parentLayer = parentLayerOutputSlot->GetOwningLayer();
207                 auto parentInfo = layerInfos.find(&parentLayer);
208                 if (parentInfo == layerInfos.end() ||
209                         !m_Subgraph->IsMergedWith(parentInfo->second.m_Subgraph.get()))
210                 {
211                     // Avoid collecting duplicate input slots
212                     InputSlot* inputSlot = &(*slot);
213                     if (std::find(inputSlots.begin(), inputSlots.end(), inputSlot) == inputSlots.end())
214                     {
215                         inputSlots.push_back(inputSlot);
216                     }
217                 }
218             }
219         }
220     }
221 
CollectNonSelectedOutputSlotsarmnn::__anon04db257d0111::LayerSelectionInfo222     void CollectNonSelectedOutputSlots(LayerSelectionInfo::LayerInfoContainer& layerInfos,
223                                        SubgraphView::IOutputSlots& outputSlots)
224     {
225         for (auto&& slot = PolymorphicDowncast<Layer*>(m_Layer)->BeginOutputSlots();
226              slot != PolymorphicDowncast<Layer*>(m_Layer)->EndOutputSlots();
227              ++slot)
228         {
229             for (InputSlot* childLayerInputSlot : slot->GetConnections())
230             {
231                 Layer& childLayer = childLayerInputSlot->GetOwningLayer();
232                 auto childInfo = layerInfos.find(&childLayer);
233                 if (childInfo == layerInfos.end() ||
234                         !m_Subgraph->IsMergedWith(childInfo->second.m_Subgraph.get()))
235                 {
236                     // Avoid collecting duplicate output slots
237                     OutputSlot* outputSlot = &(*slot);
238                     if (std::find(outputSlots.begin(), outputSlots.end(), outputSlot) == outputSlots.end())
239                     {
240                         outputSlots.push_back(outputSlot);
241                     }
242                 }
243             }
244         }
245     }
246 
247     IConnectableLayer* m_Layer;
248     /// Which subgraph this layer has been assigned to. Only valid once m_IsProcessed is true.
249     /// Two layers with different m_Subgraph pointers may in fact have been merged into the same subgraph -
250     /// see the description of the PartialSubgraph class.
251     std::shared_ptr<PartialSubgraph> m_Subgraph;
252     bool m_IsSelected;
253     bool m_IsProcessed;
254 };
255 
256 } // namespace <anonymous>
257 
258 SubgraphViewSelector::Subgraphs
SelectSubgraphs(Graph & graph,const LayerSelectorFunction & selector)259 SubgraphViewSelector::SelectSubgraphs(Graph& graph, const LayerSelectorFunction& selector)
260 {
261     SubgraphView subgraph(graph);
262     return SubgraphViewSelector::SelectSubgraphs(subgraph, selector);
263 }
264 
265 
266 template<typename Delegate>
ForEachLayerInput(LayerSelectionInfo::LayerInfoContainer & layerInfos,LayerSelectionInfo & layerInfo,Delegate function)267 void ForEachLayerInput(LayerSelectionInfo::LayerInfoContainer& layerInfos,
268                        LayerSelectionInfo& layerInfo,
269                        Delegate function)
270 {
271     Layer& layer = *PolymorphicDowncast<Layer*>(layerInfo.m_Layer);
272 
273     for (auto inputSlot : layer.GetInputSlots())
274     {
275         auto connectedInput = PolymorphicDowncast<OutputSlot*>(inputSlot.GetConnection());
276         ARMNN_ASSERT_MSG(connectedInput, "Dangling input slot detected.");
277         Layer& inputLayer = connectedInput->GetOwningLayer();
278 
279         auto parentInfo = layerInfos.find(&inputLayer);
280         if (parentInfo != layerInfos.end())
281         {
282             function(parentInfo->second);
283         }
284     }
285 }
286 
287 template<typename Delegate>
ForEachLayerOutput(LayerSelectionInfo::LayerInfoContainer & layerInfos,LayerSelectionInfo & layerInfo,Delegate function)288 void ForEachLayerOutput(LayerSelectionInfo::LayerInfoContainer& layerInfos,
289                         LayerSelectionInfo& layerInfo,
290                         Delegate function)
291 {
292     Layer& layer = *PolymorphicDowncast<Layer*>(layerInfo.m_Layer);
293 
294     for (auto& outputSlot : layer.GetOutputSlots())
295     {
296         for (auto& output : outputSlot.GetConnections())
297         {
298             Layer& childLayer = output->GetOwningLayer();
299 
300             auto childInfo = layerInfos.find(&childLayer);
301             if (childInfo != layerInfos.end())
302             {
303                 function(childInfo->second);
304             }
305         }
306     }
307 }
308 
AssignSplitId(LayerSelectionInfo::LayerInfoContainer & layerInfos,LayerSelectionInfo & layerInfo)309 void AssignSplitId(LayerSelectionInfo::LayerInfoContainer& layerInfos, LayerSelectionInfo& layerInfo)
310 {
311     // Check each input to see if we can attach ourselves to any of the subgraphs that have already been assigned.
312     ForEachLayerInput(layerInfos, layerInfo, [&](LayerSelectionInfo& parentInfo)
313     {
314         // We can only attach ourselves to the subgraph from this input if there isn't a cut here.
315         if (layerInfo.m_IsSelected == parentInfo.m_IsSelected)
316         {
317             // We also need to check that merging into this subgraph won't cause a dependency cycle between subgraphs.
318             // This will be the case if the subgraph that we will become part of is already a dependency
319             // of one of the subgraphs that are input to this layer, e.g:
320             //
321             //    0     |  The numbers (0, 1) are the subgraph IDs of each layer and we are looking at layer X.
322             //   / \    |
323             //  1   0   |  We can't merge X into subgraph 0, because the left-hand input already depends on subgraph 0.
324             //   \ /    |  We can however merge X into subgraph 1.
325             //    X     |
326             //
327             bool dependenciesOk = true;
328             ForEachLayerInput(layerInfos, layerInfo, [&](LayerSelectionInfo& otherParentInfo)
329             {
330                 // We call HasAntecedent() ~ n^2 times, where n is the number of inputs to this layer.
331                 // Hence it is important that this is efficient - see PartialSubgraph class description.
332                 if (otherParentInfo.m_Subgraph->HasAntecedent(parentInfo.m_Subgraph.get()))
333                 {
334                     dependenciesOk = false;
335                 }
336             });
337 
338             if (dependenciesOk)
339             {
340                 // Merge into the subgraph of this input. If we have already been merged into another subgraph
341                 // (from another input of this layer), then merge both of them together.
342                 if (layerInfo.m_Subgraph == nullptr)
343                 {
344                     layerInfo.m_Subgraph = parentInfo.m_Subgraph;
345                 }
346                 else
347                 {
348                     // We call MergeWith() ~ n times, where n is the number of inputs to this layer.
349                     // Therefore it does not need to be as performant as HasAntecedent().
350                     layerInfo.m_Subgraph->MergeWith(parentInfo.m_Subgraph.get());
351                 }
352             }
353         }
354     });
355 
356     // If we weren't able to merge into an existing subgraph then we need to make a new one
357     if (layerInfo.m_Subgraph == nullptr)
358     {
359         layerInfo.m_Subgraph = std::make_shared<PartialSubgraph>();
360     }
361 
362     // Record dependencies of the chosen subgraph based on the inputs of this layer.
363     ForEachLayerInput(layerInfos, layerInfo, [&](LayerSelectionInfo& parentInfo)
364     {
365         // These functions are called ~n times, where n is the number of inputs to this layer.
366         // Therefore it does not need to be as performant as HasAntecedent().
367         if (!layerInfo.m_Subgraph->IsMergedWith(parentInfo.m_Subgraph.get()))
368         {
369             layerInfo.m_Subgraph->AddDirectAntecedent(parentInfo.m_Subgraph.get());
370         }
371     });
372 }
373 
IsReadyForSplitAssignment(LayerSelectionInfo::LayerInfoContainer & layerInfos,LayerSelectionInfo & layerInfo)374 bool IsReadyForSplitAssignment(LayerSelectionInfo::LayerInfoContainer& layerInfos, LayerSelectionInfo& layerInfo)
375 {
376     bool ready = true;
377     ForEachLayerInput(layerInfos, layerInfo,
378                       [&ready](LayerSelectionInfo& parentInfo)
379                           {
380                               if (!parentInfo.m_IsProcessed)
381                               {
382                                   ready = false;
383                               }
384                           });
385     return ready;
386 }
387 
388 SubgraphViewSelector::Subgraphs
SelectSubgraphs(SubgraphView & subgraph,const LayerSelectorFunction & selector)389 SubgraphViewSelector::SelectSubgraphs(SubgraphView& subgraph, const LayerSelectorFunction& selector)
390 {
391     LayerSelectionInfo::LayerInfoContainer layerInfos;
392 
393     LayerSelectionInfo::LayerInfoQueue processQueue;
394     const SubgraphView::IConnectableLayers& subgraphLayers = subgraph.GetIConnectableLayers();
395     for (auto& layer : subgraphLayers)
396     {
397 
398         auto emplaced = layerInfos.emplace(layer, LayerSelectionInfo{PolymorphicDowncast<Layer*>(layer), selector});
399         LayerSelectionInfo& layerInfo = emplaced.first->second;
400 
401         // Start with Input type layers
402         if (layerInfo.IsInputLayer())
403         {
404             processQueue.push(&layerInfo);
405         }
406     }
407 
408     const SubgraphView::IInputSlots& subgraphInputSlots = subgraph.GetIInputSlots();
409     for (auto& inputSlot : subgraphInputSlots)
410     {
411         Layer& layer = PolymorphicDowncast<InputSlot*>(inputSlot)->GetOwningLayer();
412         auto emplaced = layerInfos.emplace(&layer, LayerSelectionInfo{&layer, selector});
413         LayerSelectionInfo& layerInfo = emplaced.first->second;
414 
415         processQueue.push(&layerInfo);
416     }
417 
418     while (!processQueue.empty())
419     {
420         LayerSelectionInfo& layerInfo = *processQueue.front();
421         processQueue.pop(); // remove front from queue
422 
423         // This layerInfo may have been added to the queue multiple times, so skip if we have already processed it
424         if (!layerInfo.m_IsProcessed)
425         {
426             // Only process this layerInfo if all inputs have been processed
427             if (!IsReadyForSplitAssignment(layerInfos, layerInfo))
428             {
429                 // Put back of the process queue if we can't process it just yet
430                 processQueue.push(&layerInfo);
431                 continue; // Skip to next iteration
432             }
433 
434             // Now we do the processing
435             AssignSplitId(layerInfos, layerInfo);
436 
437             // Queue any child nodes for processing
438             ForEachLayerOutput(layerInfos, layerInfo, [&processQueue](LayerSelectionInfo& childInfo)
439                 {
440                     processQueue.push(&childInfo);
441                 });
442 
443             // We don't need to process this node again
444             layerInfo.m_IsProcessed = true;
445         }
446     }
447 
448     // Collect all selected layers keyed by subgraph representative into a map
449     using SelectionInfoPtrs = std::vector<LayerSelectionInfo*>;
450     std::map<PartialSubgraph*, SelectionInfoPtrs> splitMap;
451     for (auto& info : layerInfos)
452     {
453         if (info.second.m_IsSelected)
454         {
455             auto it = splitMap.find(info.second.m_Subgraph->GetRepresentative());
456             if (it == splitMap.end())
457             {
458                 splitMap.insert(
459                     std::make_pair(info.second.m_Subgraph->GetRepresentative(), SelectionInfoPtrs{&info.second}));
460             }
461             else
462             {
463                 it->second.push_back(&info.second);
464             }
465         }
466     }
467 
468     // Now each entry in splitMap represents a subgraph
469     Subgraphs result;
470     for (auto& splitGraph : splitMap)
471     {
472         SubgraphView::IInputSlots inputs;
473         SubgraphView::IOutputSlots outputs;
474         SubgraphView::IConnectableLayers layers;
475         for (auto&& infoPtr : splitGraph.second)
476         {
477             infoPtr->CollectNonSelectedInputs(layerInfos, inputs);
478             infoPtr->CollectNonSelectedOutputSlots(layerInfos, outputs);
479             layers.push_back(infoPtr->m_Layer);
480         }
481 
482         // Sort lists into deterministic order, not relying on pointer values which may be different on each execution.
483         // This makes debugging the optimised graph much easier as subsequent stages can also be deterministic.
484         std::sort(inputs.begin(), inputs.end(), [](const IInputSlot* a, const IInputSlot* b)
485         {
486             auto* castA = PolymorphicDowncast<const InputSlot*>(a);
487             auto* castB = PolymorphicDowncast<const InputSlot*>(b);
488             const LayerGuid guidA = castA->GetOwningLayer().GetGuid();
489             const LayerGuid guidB = castB->GetOwningLayer().GetGuid();
490             if (guidA < guidB)
491             {
492                 return true;
493             }
494             else if (guidA == guidB)
495             {
496                 return (castA->GetSlotIndex() < castB->GetSlotIndex());
497             }
498             return false;
499         });
500         std::sort(outputs.begin(), outputs.end(), [](const IOutputSlot* a, const IOutputSlot* b)
501         {
502             auto* castA = PolymorphicDowncast<const OutputSlot*>(a);
503             auto* castB = PolymorphicDowncast<const OutputSlot*>(b);
504             const LayerGuid guidA = castA->GetOwningLayer().GetGuid();
505             const LayerGuid guidB = castB->GetOwningLayer().GetGuid();
506             if (guidA < guidB)
507             {
508                 return true;
509             }
510             else if (guidA == guidB)
511             {
512                 return (a->CalculateIndexOnOwner() < b->CalculateIndexOnOwner());
513             }
514             return false;
515         });
516         layers.sort([](const IConnectableLayer* a, const IConnectableLayer* b) { return a->GetGuid() < b->GetGuid(); });
517 
518         // Create a new sub-graph with the new lists of input/output slots and layer
519         result.emplace_back(std::make_unique<SubgraphView>(std::move(layers),
520                                                            std::move(inputs),
521                                                            std::move(outputs)));
522     }
523 
524     // Sort subgraphs list into deterministic order, not relying on pointer values which may be different on each
525     // execution. This makes debugging the optimised graph much easier as subsequent stages can also be
526     // deterministic.
527     std::sort(result.begin(), result.end(), [](const SubgraphView::SubgraphViewPtr& a,
528                                                const SubgraphView::SubgraphViewPtr& b)
529     {
530         return a->GetIConnectableLayers().front()->GetGuid() < b->GetIConnectableLayers().front()->GetGuid();
531     });
532 
533     return result;
534 }
535 
536 } // namespace armnn
537