xref: /aosp_15_r20/external/pytorch/c10/util/NetworkFlow.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/macros/Macros.h>
4 
5 #include <string>
6 #include <vector>
7 
8 /**
9  * This file provides a network flow implementation.
10  * https://en.wikipedia.org/wiki/Flow_network
11  *
12  * It aims to mirror some of the behavior of networkx, which is/was used by
13  * functorch partitioners for splitting the graph into a forward and backward
14  * graph.
15  */
16 
17 namespace c10 {
18 
// Status code for NetworkFlowGraph operations: it is the return type of
// NetworkFlowGraph::add_edge and the type of MinCutResult::status.
// The numeric values are assigned explicitly.
// NOTE(review): the per-enumerator meanings below are inferred from the names
// and from networkx's documented behavior; confirm against the implementation.
enum class C10_API_ENUM MinCutStatus {
  // Presumably: the operation completed normally.
  SUCCESS = 0,
  // Presumably: an infinite-capacity path joins source and sink, so no finite
  // minimum cut exists (networkx raises NetworkXUnbounded here) - confirm.
  UNBOUNDED = 1,
  // Presumably: a capacity or an accumulated flow reached or exceeded
  // NetworkFlowGraph::INF - confirm.
  OVERFLOW_INF = 2,
  // Presumably: the request itself was malformed - confirm which inputs are
  // rejected.
  INVALID = 3,
};
25 
26 struct MinCutResult {
27   MinCutStatus status;
28   int64_t max_flow;
29   std::vector<std::string> reachable;
30   std::vector<std::string> unreachable;
31 };
32 
33 // Modeled after networkx implementation
34 class C10_API NetworkFlowGraph {
35  public:
36   // selected such that INF + INF is < INT64_MAX
37   constexpr static int64_t INF = (1LL << 62) - 1;
38 
39   struct Edge {
40     std::string source, dest;
41     int64_t capacity;
42   };
43 
44   MinCutStatus add_edge(
45       const std::string& source,
46       const std::string& dest,
47       int64_t capacity = 1);
48 
49   MinCutResult minimum_cut(const std::string& s, const std::string& t) const;
50 
51   std::vector<Edge> edges;
52 };
53 
54 } // namespace c10
55