xref: /aosp_15_r20/external/armnn/src/armnnUtils/GraphTopologicalSort.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include <armnn/Optional.hpp>
8 
9 #include <functional>
10 #include <map>
11 #include <stack>
12 #include <vector>
13 
14 
15 namespace armnnUtils
16 {
17 
18 namespace
19 {
20 
// Progress marker for a node during the depth-first search.
// A node that has no entry in the state map has not been reached yet.
// NOTE(review): this enum and its helpers sit in an unnamed namespace inside a header, so every
// translation unit gets its own copy; consider a named detail namespace instead.
enum class NodeState
{
    Visiting = 0, // Reached, but its inputs are still being explored (it is on the DFS stack).
    Visited  = 1  // The node and all of its inputs are finished and it has been appended to the output.
};
26 
27 
28 template <typename TNodeId>
GetNextChild(TNodeId node,std::function<std::vector<TNodeId> (TNodeId)> getIncomingEdges,std::map<TNodeId,NodeState> & nodeStates)29 armnn::Optional<TNodeId> GetNextChild(TNodeId node,
30                                       std::function<std::vector<TNodeId>(TNodeId)> getIncomingEdges,
31                                       std::map<TNodeId, NodeState>& nodeStates)
32 {
33     for (TNodeId childNode : getIncomingEdges(node))
34     {
35         if (nodeStates.find(childNode) == nodeStates.end())
36         {
37             return childNode;
38         }
39         else
40         {
41             if (nodeStates.find(childNode)->second == NodeState::Visiting)
42             {
43                 return childNode;
44             }
45         }
46     }
47 
48     return {};
49 }
50 
51 template<typename TNodeId>
TopologicallySort(TNodeId initialNode,std::function<std::vector<TNodeId> (TNodeId)> getIncomingEdges,std::vector<TNodeId> & outSorted,std::map<TNodeId,NodeState> & nodeStates)52 bool TopologicallySort(
53     TNodeId initialNode,
54     std::function<std::vector<TNodeId>(TNodeId)> getIncomingEdges,
55     std::vector<TNodeId>& outSorted,
56     std::map<TNodeId, NodeState>& nodeStates)
57 {
58     std::stack<TNodeId> nodeStack;
59 
60     // If the node is never visited we should search it
61     if (nodeStates.find(initialNode) == nodeStates.end())
62     {
63         nodeStack.push(initialNode);
64     }
65 
66     while (!nodeStack.empty())
67     {
68         TNodeId current = nodeStack.top();
69 
70         nodeStates[current] = NodeState::Visiting;
71 
72         auto nextChildOfCurrent = GetNextChild(current, getIncomingEdges, nodeStates);
73 
74         if (nextChildOfCurrent)
75         {
76             TNodeId nextChild = nextChildOfCurrent.value();
77 
78             // If the child has not been searched, add to the stack and iterate over this node
79             if (nodeStates.find(nextChild) == nodeStates.end())
80             {
81                 nodeStack.push(nextChild);
82                 continue;
83             }
84 
85             // If we re-encounter a node being visited there is a cycle
86             if (nodeStates[nextChild] == NodeState::Visiting)
87             {
88                 return false;
89             }
90         }
91 
92         nodeStack.pop();
93 
94         nodeStates[current] = NodeState::Visited;
95         outSorted.push_back(current);
96     }
97 
98     return true;
99 }
100 
101 }
102 
103 // Sorts a directed acyclic graph (DAG) into a flat list such that all inputs to a node are before the node itself.
104 // Returns true if successful or false if there is an error in the graph structure (e.g. it contains a cycle).
105 // The graph is defined entirely by the "getIncomingEdges" function which the user provides. For a given node,
106 // it must return the list of nodes which are required to come before it.
107 // "targetNodes" is the list of nodes where the search begins - i.e. the nodes that you want to evaluate.
108 // This is an iterative implementation based on https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search
109 template<typename TNodeId, typename TTargetNodes>
GraphTopologicalSort(const TTargetNodes & targetNodes,std::function<std::vector<TNodeId> (TNodeId)> getIncomingEdges,std::vector<TNodeId> & outSorted)110 bool GraphTopologicalSort(
111     const TTargetNodes& targetNodes,
112     std::function<std::vector<TNodeId>(TNodeId)> getIncomingEdges,
113     std::vector<TNodeId>& outSorted)
114 {
115     outSorted.clear();
116     std::map<TNodeId, NodeState> nodeStates;
117 
118     for (TNodeId targetNode : targetNodes)
119     {
120         if (!TopologicallySort(targetNode, getIncomingEdges, outSorted, nodeStates))
121         {
122             return false;
123         }
124     }
125 
126     return true;
127 }
128 
129 }