1 //
2 // Copyright © 2017, 2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include "SubgraphViewSelector.hpp"
7 #include "Graph.hpp"
8
9 #include <armnn/utility/Assert.hpp>
10 #include <armnn/utility/IgnoreUnused.hpp>
11 #include <armnn/utility/PolymorphicDowncast.hpp>
12
13 #include <algorithm>
14 #include <map>
15 #include <queue>
16 #include <unordered_set>
17
18 namespace armnn
19 {
20
21 namespace
22 {
23
24 /// Intermediate data-structure to store the subgraph that a layer has been assigned to.
25 /// This is a "disjoint set" data structure that allows efficient merging of subgraphs,
26 /// which is a key part of the algorithm. Subgraphs are arranged in singly-linked trees
27 /// (with each node storing a pointer to its parent). Subgraphs in the same tree are considered
28 /// to have been merged. Merging subgraphs is performed by attaching one tree to another,
29 /// which is a simple pointer update.
30 ///
31 /// NOTE: Due to the way this is stored, it is almost never correct to directly compare pointers
32 /// to two PartialSubgraphs to check if two layers belong in the same subgraph. Instead you
33 /// should use IsMergedWith().
34 ///
35 /// This structure also stores information about the dependencies of each subgraph, which is needed
36 /// to determine whether certain subgraphs can be merged. Checking whether a subgraph
37 /// depends on another subgraph is a frequent operation in the algorithm (see AssignSplitId) and so this is optimized
38 /// in preference to the merging of subgraphs. This leads to an approach where each subgraph stores
39 /// a set of all the subgraphs it depends on (for a fast lookup). In order to efficiently update this
40 /// set as subgraphs are merged means we also store a set of subgraphs which *depend on us* (i.e. the
41 /// complement of our dependencies).
42 class PartialSubgraph
43 {
44 public:
45 /// If this subgraph has been merged with another then there is an agreed "representative" for the combined
46 /// subgraph, which uniquely identifies the subgraph.
GetRepresentative()47 PartialSubgraph* GetRepresentative()
48 {
49 // Recurse up the tree to find the root node.
50 if (m_Parent == nullptr)
51 {
52 return this;
53 }
54 else
55 {
56 PartialSubgraph* result = m_Parent->GetRepresentative();
57 // Update our parent pointer to point directly to the root in order to speed up future calls to this method.
58 // This essentially "flattens" the tree.
59 m_Parent = result;
60 return result;
61 }
62 }
63
64 /// Merges this subgraph with another.
MergeWith(PartialSubgraph * other)65 void MergeWith(PartialSubgraph* other)
66 {
67 if (m_Parent == nullptr)
68 {
69 other = other->GetRepresentative();
70 if (this == other)
71 {
72 // Already merged - no-op
73 return;
74 }
75 m_Parent = other;
76
77 // Update others' dependency sets to point to the new representative rather than us.
78 // Keeping these up-to-date means we can rely on these sets containing representatives when
79 // we perform a lookup in HasAntecedent() and so don't need to resolve the representative for each element
80 // of the set. See description at the top of this class for more rationale.
81 for (PartialSubgraph* a : m_Antecedents)
82 {
83 size_t numErased = a->m_Dependants.erase(this);
84 ARMNN_ASSERT(numErased == 1);
85 IgnoreUnused(numErased);
86 a->m_Dependants.insert(m_Parent);
87 }
88 for (PartialSubgraph* a : m_Dependants)
89 {
90 size_t numErased = a->m_Antecedents.erase(this);
91 ARMNN_ASSERT(numErased == 1);
92 IgnoreUnused(numErased);
93 a->m_Antecedents.insert(m_Parent);
94 }
95
96 // Merge our dependency sets into our new representative.
97 // We no longer need to maintain our own sets, as requests will always be forwarded to the representative.
98 m_Parent->m_Antecedents.insert(m_Antecedents.begin(), m_Antecedents.end());
99 m_Antecedents.clear();
100 m_Parent->m_Dependants.insert(m_Dependants.begin(), m_Dependants.end());
101 m_Dependants.clear();
102 }
103 else
104 {
105 // Defer request to the representative
106 GetRepresentative()->MergeWith(other);
107 }
108 }
109
110 /// Checks if this subgraph has been merged with the given subgraph.
IsMergedWith(PartialSubgraph * other)111 bool IsMergedWith(PartialSubgraph* other)
112 {
113 return GetRepresentative() == other->GetRepresentative();
114 }
115
116 /// Marks the given subgraph as a direct antecedent (dependency) of this one.
AddDirectAntecedent(PartialSubgraph * antecedent)117 void AddDirectAntecedent(PartialSubgraph* antecedent)
118 {
119 if (m_Parent == nullptr)
120 {
121 antecedent = antecedent->GetRepresentative();
122
123 m_Antecedents.insert(antecedent);
124 // Also record all of its antecedents, so that we end up with direct and indirect antecedents.
125 // This makes the lookup in HasAntecedent() faster.
126 m_Antecedents.insert(antecedent->m_Antecedents.begin(), antecedent->m_Antecedents.end());
127 // All of our dependents also need to include the new antecedents
128 for (PartialSubgraph* d : m_Dependants)
129 {
130 d->m_Antecedents.insert(antecedent);
131 d->m_Antecedents.insert(antecedent->m_Antecedents.begin(), antecedent->m_Antecedents.end());
132 }
133
134 // Store reverse dependencies as well, required so that we can efficiently navigate the graph
135 // when making updates.
136 antecedent->m_Dependants.insert(this);
137 antecedent->m_Dependants.insert(m_Dependants.begin(), m_Dependants.end());
138 for (PartialSubgraph* a : antecedent->m_Antecedents)
139 {
140 a->m_Dependants.insert(this);
141 a->m_Dependants.insert(m_Dependants.begin(), m_Dependants.end());
142 }
143 }
144 else
145 {
146 // Defer request to the representative
147 GetRepresentative()->AddDirectAntecedent(antecedent);
148 }
149 }
150
151 /// Checks if this subgraph is dependent on the given subgraph, either directly or indirectly.
HasAntecedent(PartialSubgraph * antecedent)152 bool HasAntecedent(PartialSubgraph* antecedent)
153 {
154 if (m_Parent == nullptr)
155 {
156 antecedent = antecedent->GetRepresentative();
157 // Thanks to keeping this set updated in MergeWith and AddDirectAntecedent, we can do an efficient lookup.
158 return m_Antecedents.count(antecedent) > 0;
159 }
160 else
161 {
162 // Defer request to the representative
163 return GetRepresentative()->HasAntecedent(antecedent);
164 }
165 }
166
167 private:
168 /// Pointer to the parent node in the tree. If this is null then we are the representative for our merged subgraph.
169 PartialSubgraph* m_Parent;
170 /// The representatives of all the subgraphs which we depend on, either directly or indirectly.
171 std::unordered_set<PartialSubgraph*> m_Antecedents;
172 /// The representatives of all the subgraphs which depend on us, either directly or indirectly.
173 std::unordered_set<PartialSubgraph*> m_Dependants;
174 };
175
176 /// Intermediate data structure to store information associated with a particular layer.
177 struct LayerSelectionInfo
178 {
179 using LayerInfoContainer = std::map<IConnectableLayer*, LayerSelectionInfo>;
180 using LayerInfoQueue = std::queue<LayerSelectionInfo*>;
181
LayerSelectionInfoarmnn::__anon04db257d0111::LayerSelectionInfo182 LayerSelectionInfo(Layer* layer, const SubgraphViewSelector::LayerSelectorFunction& selector)
183 : m_Layer{layer}
184 , m_Subgraph{nullptr}
185 , m_IsSelected{selector(*layer)}
186 , m_IsProcessed(false)
187 {
188 }
189
IsInputLayerarmnn::__anon04db257d0111::LayerSelectionInfo190 bool IsInputLayer() const
191 {
192 return m_Layer->GetType() == armnn::LayerType::Input || m_Layer->GetType() == armnn::LayerType::Constant;
193 }
194
CollectNonSelectedInputsarmnn::__anon04db257d0111::LayerSelectionInfo195 void CollectNonSelectedInputs(LayerSelectionInfo::LayerInfoContainer& layerInfos,
196 SubgraphView::IInputSlots& inputSlots)
197 {
198 for (auto&& slot = PolymorphicDowncast<Layer*>(m_Layer)->BeginInputSlots();
199 slot != PolymorphicDowncast<Layer*>(m_Layer)->EndInputSlots();
200 ++slot)
201 {
202 OutputSlot* parentLayerOutputSlot = slot->GetConnectedOutputSlot();
203 ARMNN_ASSERT_MSG(parentLayerOutputSlot != nullptr, "The input slots must be connected here.");
204 if (parentLayerOutputSlot)
205 {
206 Layer& parentLayer = parentLayerOutputSlot->GetOwningLayer();
207 auto parentInfo = layerInfos.find(&parentLayer);
208 if (parentInfo == layerInfos.end() ||
209 !m_Subgraph->IsMergedWith(parentInfo->second.m_Subgraph.get()))
210 {
211 // Avoid collecting duplicate input slots
212 InputSlot* inputSlot = &(*slot);
213 if (std::find(inputSlots.begin(), inputSlots.end(), inputSlot) == inputSlots.end())
214 {
215 inputSlots.push_back(inputSlot);
216 }
217 }
218 }
219 }
220 }
221
CollectNonSelectedOutputSlotsarmnn::__anon04db257d0111::LayerSelectionInfo222 void CollectNonSelectedOutputSlots(LayerSelectionInfo::LayerInfoContainer& layerInfos,
223 SubgraphView::IOutputSlots& outputSlots)
224 {
225 for (auto&& slot = PolymorphicDowncast<Layer*>(m_Layer)->BeginOutputSlots();
226 slot != PolymorphicDowncast<Layer*>(m_Layer)->EndOutputSlots();
227 ++slot)
228 {
229 for (InputSlot* childLayerInputSlot : slot->GetConnections())
230 {
231 Layer& childLayer = childLayerInputSlot->GetOwningLayer();
232 auto childInfo = layerInfos.find(&childLayer);
233 if (childInfo == layerInfos.end() ||
234 !m_Subgraph->IsMergedWith(childInfo->second.m_Subgraph.get()))
235 {
236 // Avoid collecting duplicate output slots
237 OutputSlot* outputSlot = &(*slot);
238 if (std::find(outputSlots.begin(), outputSlots.end(), outputSlot) == outputSlots.end())
239 {
240 outputSlots.push_back(outputSlot);
241 }
242 }
243 }
244 }
245 }
246
247 IConnectableLayer* m_Layer;
248 /// Which subgraph this layer has been assigned to. Only valid once m_IsProcessed is true.
249 /// Two layers with different m_Subgraph pointers may in fact have been merged into the same subgraph -
250 /// see the description of the PartialSubgraph class.
251 std::shared_ptr<PartialSubgraph> m_Subgraph;
252 bool m_IsSelected;
253 bool m_IsProcessed;
254 };
255
256 } // namespace <anonymous>
257
258 SubgraphViewSelector::Subgraphs
SelectSubgraphs(Graph & graph,const LayerSelectorFunction & selector)259 SubgraphViewSelector::SelectSubgraphs(Graph& graph, const LayerSelectorFunction& selector)
260 {
261 SubgraphView subgraph(graph);
262 return SubgraphViewSelector::SelectSubgraphs(subgraph, selector);
263 }
264
265
266 template<typename Delegate>
ForEachLayerInput(LayerSelectionInfo::LayerInfoContainer & layerInfos,LayerSelectionInfo & layerInfo,Delegate function)267 void ForEachLayerInput(LayerSelectionInfo::LayerInfoContainer& layerInfos,
268 LayerSelectionInfo& layerInfo,
269 Delegate function)
270 {
271 Layer& layer = *PolymorphicDowncast<Layer*>(layerInfo.m_Layer);
272
273 for (auto inputSlot : layer.GetInputSlots())
274 {
275 auto connectedInput = PolymorphicDowncast<OutputSlot*>(inputSlot.GetConnection());
276 ARMNN_ASSERT_MSG(connectedInput, "Dangling input slot detected.");
277 Layer& inputLayer = connectedInput->GetOwningLayer();
278
279 auto parentInfo = layerInfos.find(&inputLayer);
280 if (parentInfo != layerInfos.end())
281 {
282 function(parentInfo->second);
283 }
284 }
285 }
286
287 template<typename Delegate>
ForEachLayerOutput(LayerSelectionInfo::LayerInfoContainer & layerInfos,LayerSelectionInfo & layerInfo,Delegate function)288 void ForEachLayerOutput(LayerSelectionInfo::LayerInfoContainer& layerInfos,
289 LayerSelectionInfo& layerInfo,
290 Delegate function)
291 {
292 Layer& layer = *PolymorphicDowncast<Layer*>(layerInfo.m_Layer);
293
294 for (auto& outputSlot : layer.GetOutputSlots())
295 {
296 for (auto& output : outputSlot.GetConnections())
297 {
298 Layer& childLayer = output->GetOwningLayer();
299
300 auto childInfo = layerInfos.find(&childLayer);
301 if (childInfo != layerInfos.end())
302 {
303 function(childInfo->second);
304 }
305 }
306 }
307 }
308
AssignSplitId(LayerSelectionInfo::LayerInfoContainer & layerInfos,LayerSelectionInfo & layerInfo)309 void AssignSplitId(LayerSelectionInfo::LayerInfoContainer& layerInfos, LayerSelectionInfo& layerInfo)
310 {
311 // Check each input to see if we can attach ourselves to any of the subgraphs that have already been assigned.
312 ForEachLayerInput(layerInfos, layerInfo, [&](LayerSelectionInfo& parentInfo)
313 {
314 // We can only attach ourselves to the subgraph from this input if there isn't a cut here.
315 if (layerInfo.m_IsSelected == parentInfo.m_IsSelected)
316 {
317 // We also need to check that merging into this subgraph won't cause a dependency cycle between subgraphs.
318 // This will be the case if the subgraph that we will become part of is already a dependency
319 // of one of the subgraphs that are input to this layer, e.g:
320 //
321 // 0 | The numbers (0, 1) are the subgraph IDs of each layer and we are looking at layer X.
322 // / \ |
323 // 1 0 | We can't merge X into subgraph 0, because the left-hand input already depends on subgraph 0.
324 // \ / | We can however merge X into subgraph 1.
325 // X |
326 //
327 bool dependenciesOk = true;
328 ForEachLayerInput(layerInfos, layerInfo, [&](LayerSelectionInfo& otherParentInfo)
329 {
330 // We call HasAntecedent() ~ n^2 times, where n is the number of inputs to this layer.
331 // Hence it is important that this is efficient - see PartialSubgraph class description.
332 if (otherParentInfo.m_Subgraph->HasAntecedent(parentInfo.m_Subgraph.get()))
333 {
334 dependenciesOk = false;
335 }
336 });
337
338 if (dependenciesOk)
339 {
340 // Merge into the subgraph of this input. If we have already been merged into another subgraph
341 // (from another input of this layer), then merge both of them together.
342 if (layerInfo.m_Subgraph == nullptr)
343 {
344 layerInfo.m_Subgraph = parentInfo.m_Subgraph;
345 }
346 else
347 {
348 // We call MergeWith() ~ n times, where n is the number of inputs to this layer.
349 // Therefore it does not need to be as performant as HasAntecedent().
350 layerInfo.m_Subgraph->MergeWith(parentInfo.m_Subgraph.get());
351 }
352 }
353 }
354 });
355
356 // If we weren't able to merge into an existing subgraph then we need to make a new one
357 if (layerInfo.m_Subgraph == nullptr)
358 {
359 layerInfo.m_Subgraph = std::make_shared<PartialSubgraph>();
360 }
361
362 // Record dependencies of the chosen subgraph based on the inputs of this layer.
363 ForEachLayerInput(layerInfos, layerInfo, [&](LayerSelectionInfo& parentInfo)
364 {
365 // These functions are called ~n times, where n is the number of inputs to this layer.
366 // Therefore it does not need to be as performant as HasAntecedent().
367 if (!layerInfo.m_Subgraph->IsMergedWith(parentInfo.m_Subgraph.get()))
368 {
369 layerInfo.m_Subgraph->AddDirectAntecedent(parentInfo.m_Subgraph.get());
370 }
371 });
372 }
373
IsReadyForSplitAssignment(LayerSelectionInfo::LayerInfoContainer & layerInfos,LayerSelectionInfo & layerInfo)374 bool IsReadyForSplitAssignment(LayerSelectionInfo::LayerInfoContainer& layerInfos, LayerSelectionInfo& layerInfo)
375 {
376 bool ready = true;
377 ForEachLayerInput(layerInfos, layerInfo,
378 [&ready](LayerSelectionInfo& parentInfo)
379 {
380 if (!parentInfo.m_IsProcessed)
381 {
382 ready = false;
383 }
384 });
385 return ready;
386 }
387
388 SubgraphViewSelector::Subgraphs
SelectSubgraphs(SubgraphView & subgraph,const LayerSelectorFunction & selector)389 SubgraphViewSelector::SelectSubgraphs(SubgraphView& subgraph, const LayerSelectorFunction& selector)
390 {
391 LayerSelectionInfo::LayerInfoContainer layerInfos;
392
393 LayerSelectionInfo::LayerInfoQueue processQueue;
394 const SubgraphView::IConnectableLayers& subgraphLayers = subgraph.GetIConnectableLayers();
395 for (auto& layer : subgraphLayers)
396 {
397
398 auto emplaced = layerInfos.emplace(layer, LayerSelectionInfo{PolymorphicDowncast<Layer*>(layer), selector});
399 LayerSelectionInfo& layerInfo = emplaced.first->second;
400
401 // Start with Input type layers
402 if (layerInfo.IsInputLayer())
403 {
404 processQueue.push(&layerInfo);
405 }
406 }
407
408 const SubgraphView::IInputSlots& subgraphInputSlots = subgraph.GetIInputSlots();
409 for (auto& inputSlot : subgraphInputSlots)
410 {
411 Layer& layer = PolymorphicDowncast<InputSlot*>(inputSlot)->GetOwningLayer();
412 auto emplaced = layerInfos.emplace(&layer, LayerSelectionInfo{&layer, selector});
413 LayerSelectionInfo& layerInfo = emplaced.first->second;
414
415 processQueue.push(&layerInfo);
416 }
417
418 while (!processQueue.empty())
419 {
420 LayerSelectionInfo& layerInfo = *processQueue.front();
421 processQueue.pop(); // remove front from queue
422
423 // This layerInfo may have been added to the queue multiple times, so skip if we have already processed it
424 if (!layerInfo.m_IsProcessed)
425 {
426 // Only process this layerInfo if all inputs have been processed
427 if (!IsReadyForSplitAssignment(layerInfos, layerInfo))
428 {
429 // Put back of the process queue if we can't process it just yet
430 processQueue.push(&layerInfo);
431 continue; // Skip to next iteration
432 }
433
434 // Now we do the processing
435 AssignSplitId(layerInfos, layerInfo);
436
437 // Queue any child nodes for processing
438 ForEachLayerOutput(layerInfos, layerInfo, [&processQueue](LayerSelectionInfo& childInfo)
439 {
440 processQueue.push(&childInfo);
441 });
442
443 // We don't need to process this node again
444 layerInfo.m_IsProcessed = true;
445 }
446 }
447
448 // Collect all selected layers keyed by subgraph representative into a map
449 using SelectionInfoPtrs = std::vector<LayerSelectionInfo*>;
450 std::map<PartialSubgraph*, SelectionInfoPtrs> splitMap;
451 for (auto& info : layerInfos)
452 {
453 if (info.second.m_IsSelected)
454 {
455 auto it = splitMap.find(info.second.m_Subgraph->GetRepresentative());
456 if (it == splitMap.end())
457 {
458 splitMap.insert(
459 std::make_pair(info.second.m_Subgraph->GetRepresentative(), SelectionInfoPtrs{&info.second}));
460 }
461 else
462 {
463 it->second.push_back(&info.second);
464 }
465 }
466 }
467
468 // Now each entry in splitMap represents a subgraph
469 Subgraphs result;
470 for (auto& splitGraph : splitMap)
471 {
472 SubgraphView::IInputSlots inputs;
473 SubgraphView::IOutputSlots outputs;
474 SubgraphView::IConnectableLayers layers;
475 for (auto&& infoPtr : splitGraph.second)
476 {
477 infoPtr->CollectNonSelectedInputs(layerInfos, inputs);
478 infoPtr->CollectNonSelectedOutputSlots(layerInfos, outputs);
479 layers.push_back(infoPtr->m_Layer);
480 }
481
482 // Sort lists into deterministic order, not relying on pointer values which may be different on each execution.
483 // This makes debugging the optimised graph much easier as subsequent stages can also be deterministic.
484 std::sort(inputs.begin(), inputs.end(), [](const IInputSlot* a, const IInputSlot* b)
485 {
486 auto* castA = PolymorphicDowncast<const InputSlot*>(a);
487 auto* castB = PolymorphicDowncast<const InputSlot*>(b);
488 const LayerGuid guidA = castA->GetOwningLayer().GetGuid();
489 const LayerGuid guidB = castB->GetOwningLayer().GetGuid();
490 if (guidA < guidB)
491 {
492 return true;
493 }
494 else if (guidA == guidB)
495 {
496 return (castA->GetSlotIndex() < castB->GetSlotIndex());
497 }
498 return false;
499 });
500 std::sort(outputs.begin(), outputs.end(), [](const IOutputSlot* a, const IOutputSlot* b)
501 {
502 auto* castA = PolymorphicDowncast<const OutputSlot*>(a);
503 auto* castB = PolymorphicDowncast<const OutputSlot*>(b);
504 const LayerGuid guidA = castA->GetOwningLayer().GetGuid();
505 const LayerGuid guidB = castB->GetOwningLayer().GetGuid();
506 if (guidA < guidB)
507 {
508 return true;
509 }
510 else if (guidA == guidB)
511 {
512 return (a->CalculateIndexOnOwner() < b->CalculateIndexOnOwner());
513 }
514 return false;
515 });
516 layers.sort([](const IConnectableLayer* a, const IConnectableLayer* b) { return a->GetGuid() < b->GetGuid(); });
517
518 // Create a new sub-graph with the new lists of input/output slots and layer
519 result.emplace_back(std::make_unique<SubgraphView>(std::move(layers),
520 std::move(inputs),
521 std::move(outputs)));
522 }
523
524 // Sort subgraphs list into deterministic order, not relying on pointer values which may be different on each
525 // execution. This makes debugging the optimised graph much easier as subsequent stages can also be
526 // deterministic.
527 std::sort(result.begin(), result.end(), [](const SubgraphView::SubgraphViewPtr& a,
528 const SubgraphView::SubgraphViewPtr& b)
529 {
530 return a->GetIConnectableLayers().front()->GetGuid() < b->GetIConnectableLayers().front()->GetGuid();
531 });
532
533 return result;
534 }
535
536 } // namespace armnn
537