xref: /aosp_15_r20/external/pytorch/c10/util/NetworkFlow.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/util/NetworkFlow.h>
2 
3 #include <c10/util/Exception.h>
4 
5 #include <iostream>
6 #include <optional>
7 #include <queue>
8 #include <unordered_map>
9 #include <vector>
10 
11 namespace c10 {
12 
13 namespace {
14 
15 struct DinicFlowGraph {
16   // [Note: Dinic graph format]
17   // The graph is represented as an adjacency list:
18   //   for a vertex u, adj[u] lists all the outgoing edges from u.
19   //   adj[u][i] is the index of the i-th outgoing edge from u.
20   //   To get information on the i-th outgoing edge from u, use
21   //   edges[adj[i][i]].
22   // The edges are directed and are paired with a reverse edge.
23   //   For example, an edge u->v is paired with a v->u edge.
24   //   The index of the reverse edge of e is stored as e.other_idx.
25   // Capacities and flows: each edge has a capacity and a flow
26   //   associated with it. When flow is added to an edge, it removes
27   //   capacity from the reverse edge.
28   struct Edge {
29     size_t u, v;
30     int64_t capacity;
31     int64_t flow;
32     size_t other_idx; // reverse edge
33 
residual_capacityc10::__anonc86546720111::DinicFlowGraph::Edge34     int64_t residual_capacity() const {
35       return capacity - flow;
36     }
37   };
38 
39   std::vector<Edge> edges;
40   std::vector<std::vector<size_t>> adj; // adjacency list
41   std::vector<std::string> vertex_names;
42   std::unordered_map<std::string, size_t> mapping;
43   size_t graph_size;
44 
add_flowc10::__anonc86546720111::DinicFlowGraph45   void add_flow(Edge& e, int64_t more) {
46     e.flow += more;
47     edges[e.other_idx].flow -= more;
48   }
49 
reverse_edgec10::__anonc86546720111::DinicFlowGraph50   const Edge& reverse_edge(const Edge& e) const {
51     return edges[e.other_idx];
52   }
53 
DinicFlowGraphc10::__anonc86546720111::DinicFlowGraph54   DinicFlowGraph(const NetworkFlowGraph& g) {
55     size_t vertex_count = 0;
56 
57     auto get_idx = [&vertex_count, this](const std::string& name) {
58       if (!mapping.count(name)) {
59         TORCH_CHECK(vertex_count == vertex_names.size());
60         vertex_names.push_back(name);
61         size_t idx = vertex_count;
62         vertex_count++;
63         mapping[name] = idx;
64         return idx;
65       }
66       return mapping[name];
67     };
68 
69     for (const auto& [source, dest, capacity] : g.edges) {
70       auto u = get_idx(source);
71       auto v = get_idx(dest);
72       auto fwd_idx = edges.size();
73       auto bwd_idx = edges.size() + 1;
74       edges.push_back({u, v, capacity, 0, bwd_idx});
75       edges.push_back({v, u, 0, 0, fwd_idx});
76     }
77 
78     // NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer)
79     graph_size = mapping.size();
80     adj.resize(graph_size);
81 
82     for (size_t i = 0; i < edges.size(); ++i) {
83       adj[edges[i].u].push_back(i);
84     }
85   }
86 
residual_level_graphc10::__anonc86546720111::DinicFlowGraph87   std::vector<std::vector<size_t>> residual_level_graph(size_t s) const {
88     // The residual graph is the graph including only edges
89     //   where edge.residual_capacity() is nonzero, i.e.
90     //   edge.capacity > edge.flow.
91     // The residual level graph is constructed by:
92     //   1. doing a BFS on the residual graph, assigning levels
93     //      to each vertex.
94     //   2. only include edges u->v where level[v] == leve[u] + 1
95     std::queue<size_t> q;
96     // let level[u] = 0 if it has not been visited yet.
97     std::vector<size_t> level(graph_size, 0);
98     // TODO(davidberard98) we can create this once and reuse it
99     std::vector<std::vector<size_t>> output_adjacency(graph_size);
100     level[s] = 1;
101     q.push(s);
102     while (!q.empty()) {
103       size_t u = q.front();
104       q.pop();
105       for (const auto& edge_idx : adj[u]) {
106         const auto& e = edges[edge_idx];
107         if (e.residual_capacity()) {
108           if (level[e.v] == 0) {
109             level[e.v] = level[e.u] + 1;
110             q.push(e.v);
111           }
112           if (level[e.v] == level[e.u] + 1) {
113             output_adjacency[e.u].push_back(edge_idx);
114           }
115         }
116       }
117     }
118 
119     return output_adjacency;
120   }
121 
augment_iterationc10::__anonc86546720111::DinicFlowGraph122   std::pair<MinCutStatus, int64_t> augment_iteration(size_t s, size_t t) {
123     // Perform one iteration of augmenting the flow.
124     // 1. Create the level graph
125     // 2. DFS to find augmenting paths
126     // 3. If encountering edges that don't lead to augmenting paths,
127     //    trim them from the level graph.
128     // 4. Repeat 2-3 until we can't find any augmenting paths.
129     std::vector<std::vector<size_t>> level_adj = residual_level_graph(s);
130 
131     // TODO(davidberard98): implement this DFS with a stack
132     std::function<int64_t(size_t, size_t, int64_t)> dfs;
133     dfs = [&level_adj, &dfs, this](
134               size_t u, size_t t, int64_t cur_cap) -> int64_t {
135       if (u == t) {
136         return cur_cap;
137       }
138       while (!level_adj[u].empty()) {
139         // Iterate over the outgoing edges from u.
140         // If take an edge and find that we can't augment using this edge,
141         //   then delete it from our level graph.
142         // If we take an edge and it does find an augmenting path, then
143         //   take the augmenting path and exit early
144         auto edge_idx = level_adj[u].back();
145         auto& e = edges[edge_idx];
146         auto taken_cap = dfs(e.v, t, std::min(cur_cap, e.residual_capacity()));
147         if (taken_cap) {
148           add_flow(e, taken_cap);
149           if (!e.residual_capacity()) {
150             // this edge has no remaining residual capacity, remove it.
151             level_adj[u].pop_back();
152           }
153           return taken_cap;
154         } else {
155           // we can't get any capacity from this edge, remove it.
156           level_adj[u].pop_back();
157         }
158       }
159       return 0;
160     };
161 
162     int64_t additional_flow = 0;
163     while (int64_t f = dfs(s, t, NetworkFlowGraph::INF)) {
164       if (f == NetworkFlowGraph::INF) {
165         return {MinCutStatus::UNBOUNDED, 0};
166       }
167       additional_flow += f;
168       if (additional_flow >= NetworkFlowGraph::INF) {
169         return {MinCutStatus::OVERFLOW_INF, 0};
170       }
171     }
172 
173     return {MinCutStatus::SUCCESS, additional_flow};
174   }
175 
compute_max_flowc10::__anonc86546720111::DinicFlowGraph176   std::pair<MinCutStatus, int64_t> compute_max_flow(size_t s, size_t t) {
177     int64_t total_flow = 0;
178     while (true) {
179       auto [status, additional_flow] = augment_iteration(s, t);
180       if (status != MinCutStatus::SUCCESS) {
181         return {status, 0};
182       }
183       if (additional_flow == 0) {
184         break;
185       }
186       total_flow += additional_flow;
187       if (total_flow >= NetworkFlowGraph::INF) {
188         return {MinCutStatus::OVERFLOW_INF, 0};
189       }
190     }
191     return {MinCutStatus::SUCCESS, total_flow};
192   }
193 
reverse_bfs_reachablec10::__anonc86546720111::DinicFlowGraph194   std::vector<bool> reverse_bfs_reachable(size_t t) const {
195     // Find all vertices that are reachable from t in the reverse
196     //   residual graph.
197     std::vector<bool> seen(graph_size, false);
198     seen[t] = true;
199     std::queue<size_t> q;
200     q.push(t);
201     while (!q.empty()) {
202       auto x = q.front();
203       q.pop();
204       for (auto& edge_idx : adj[x]) {
205         // the edge that goes u -> v where v == x
206         const auto& e = reverse_edge(edges[edge_idx]);
207         if (!e.residual_capacity()) {
208           continue;
209         }
210 
211         if (!seen[e.u]) {
212           seen[e.u] = true;
213           q.push(e.u);
214         }
215       }
216     }
217     return seen;
218   }
219 
partitionc10::__anonc86546720111::DinicFlowGraph220   std::pair<std::vector<size_t>, std::vector<size_t>> partition(
221       size_t s,
222       size_t t) {
223     // Note: the partitioning returns "reachable" / "unreachable",
224     //   but specifically, for "unreachable", it returns "all vertices
225     //   that are reachable from t in the reverse residual graph"
226     //   and for "reachable" it returns all other nodes. This mirrors
227     //   the behavior of networkx.
228     auto can_reach_t = reverse_bfs_reachable(t);
229     std::vector<size_t> reachable, unreachable;
230     for (size_t i = 0; i < graph_size; ++i) {
231       if (can_reach_t[i]) {
232         unreachable.push_back(i);
233       } else {
234         reachable.push_back(i);
235       }
236     }
237     return std::pair<std::vector<size_t>, std::vector<size_t>>(
238         std::move(reachable), std::move(unreachable));
239   }
240 
minimum_cutc10::__anonc86546720111::DinicFlowGraph241   MinCutResult minimum_cut(const std::string& s, const std::string& t) {
242     if (mapping.find(s) == mapping.end() || mapping.find(t) == mapping.end()) {
243       return {
244           MinCutStatus::INVALID, // status
245           0, // max_flow
246           {}, // reachable
247           {}, // unreachable
248       };
249     }
250     auto s_int = mapping[s];
251     auto t_int = mapping[t];
252     auto [status, max_flow] = compute_max_flow(s_int, t_int);
253     if (status != MinCutStatus::SUCCESS) {
254       return {
255           status, // status
256           0, // max_flow
257           {}, // reachable
258           {}, // unreachable
259       };
260     }
261 
262     auto [reachable_idxs, unreachable_idxs] = partition(s_int, t_int);
263     std::vector<std::string> reachable, unreachable;
264 
265     auto idxs_to_names = [&](std::vector<size_t>& src,
266                              std::vector<std::string>& dest) {
267       dest.reserve(src.size());
268       for (auto idx : src) {
269         dest.push_back(vertex_names[idx]);
270       }
271     };
272 
273     idxs_to_names(reachable_idxs, reachable);
274     idxs_to_names(unreachable_idxs, unreachable);
275 
276     return {
277         MinCutStatus::SUCCESS,
278         max_flow,
279         reachable,
280         unreachable,
281     };
282   }
283 };
284 
285 } // namespace
286 
add_edge(const std::string & source,const std::string & dest,int64_t capacity)287 MinCutStatus NetworkFlowGraph::add_edge(
288     const std::string& source,
289     const std::string& dest,
290     int64_t capacity) {
291   edges.push_back({source, dest, capacity});
292   return MinCutStatus::SUCCESS;
293 }
294 
minimum_cut(const std::string & s,const std::string & t) const295 MinCutResult NetworkFlowGraph::minimum_cut(
296     const std::string& s,
297     const std::string& t) const {
298   auto flow_graph = DinicFlowGraph(*this);
299 
300   return flow_graph.minimum_cut(s, t);
301 }
302 
303 } // namespace c10
304