1 use std::collections::HashMap;
2 
3 use std::hash::Hash;
4 
5 use crate::algo::{BoundedMeasure, NegativeCycle};
6 use crate::visit::{
7     EdgeRef, GraphProp, IntoEdgeReferences, IntoNodeIdentifiers, NodeCompactIndexable,
8 };
9 
10 #[allow(clippy::type_complexity, clippy::needless_range_loop)]
11 /// \[Generic\] [Floyd–Warshall algorithm](https://en.wikipedia.org/wiki/Floyd%E2%80%93Warshall_algorithm) is an algorithm for all pairs shortest path problem
12 ///
13 /// Compute shortest paths in a weighted graph with positive or negative edge weights (but with no negative cycles)
14 ///
15 /// # Arguments
16 /// * `graph`: graph with no negative cycle
17 /// * `edge_cost`: closure that returns cost of a particular edge
18 ///
19 /// # Returns
20 /// * `Ok`: (if graph contains no negative cycle) a hashmap containing all pairs shortest paths
21 /// * `Err`: if graph contains negative cycle.
22 ///
23 /// # Examples
24 /// ```rust
25 /// use petgraph::{prelude::*, Graph, Directed};
26 /// use petgraph::algo::floyd_warshall;
27 /// use std::collections::HashMap;
28 ///
29 /// let mut graph: Graph<(), (), Directed> = Graph::new();
30 /// let a = graph.add_node(());
31 /// let b = graph.add_node(());
32 /// let c = graph.add_node(());
33 /// let d = graph.add_node(());
34 ///
35 /// graph.extend_with_edges(&[
36 ///    (a, b),
37 ///    (a, c),
38 ///    (a, d),
39 ///    (b, c),
40 ///    (b, d),
41 ///    (c, d)
42 /// ]);
43 ///
44 /// let weight_map: HashMap<(NodeIndex, NodeIndex), i32> = [
45 ///    ((a, a), 0), ((a, b), 1), ((a, c), 4), ((a, d), 10),
46 ///    ((b, b), 0), ((b, c), 2), ((b, d), 2),
47 ///    ((c, c), 0), ((c, d), 2)
48 /// ].iter().cloned().collect();
49 /// //     ----- b --------
50 /// //    |      ^         | 2
51 /// //    |    1 |    4    v
52 /// //  2 |      a ------> c
53 /// //    |   10 |         | 2
54 /// //    |      v         v
55 /// //     --->  d <-------
56 ///
57 /// let inf = std::i32::MAX;
58 /// let expected_res: HashMap<(NodeIndex, NodeIndex), i32> = [
59 ///    ((a, a), 0), ((a, b), 1), ((a, c), 3), ((a, d), 3),
60 ///    ((b, a), inf), ((b, b), 0), ((b, c), 2), ((b, d), 2),
61 ///    ((c, a), inf), ((c, b), inf), ((c, c), 0), ((c, d), 2),
62 ///    ((d, a), inf), ((d, b), inf), ((d, c), inf), ((d, d), 0),
63 /// ].iter().cloned().collect();
64 ///
65 ///
66 /// let res = floyd_warshall(&graph, |edge| {
67 ///     if let Some(weight) = weight_map.get(&(edge.source(), edge.target())) {
68 ///         *weight
69 ///     } else {
70 ///         inf
71 ///     }
72 /// }).unwrap();
73 ///
74 /// let nodes = [a, b, c, d];
75 /// for node1 in &nodes {
76 ///     for node2 in &nodes {
77 ///         assert_eq!(res.get(&(*node1, *node2)).unwrap(), expected_res.get(&(*node1, *node2)).unwrap());
78 ///     }
79 /// }
80 /// ```
floyd_warshall<G, F, K>( graph: G, mut edge_cost: F, ) -> Result<HashMap<(G::NodeId, G::NodeId), K>, NegativeCycle> where G: NodeCompactIndexable + IntoEdgeReferences + IntoNodeIdentifiers + GraphProp, G::NodeId: Eq + Hash, F: FnMut(G::EdgeRef) -> K, K: BoundedMeasure + Copy,81 pub fn floyd_warshall<G, F, K>(
82     graph: G,
83     mut edge_cost: F,
84 ) -> Result<HashMap<(G::NodeId, G::NodeId), K>, NegativeCycle>
85 where
86     G: NodeCompactIndexable + IntoEdgeReferences + IntoNodeIdentifiers + GraphProp,
87     G::NodeId: Eq + Hash,
88     F: FnMut(G::EdgeRef) -> K,
89     K: BoundedMeasure + Copy,
90 {
91     let num_of_nodes = graph.node_count();
92 
93     // |V|x|V| matrix
94     let mut dist = vec![vec![K::max(); num_of_nodes]; num_of_nodes];
95 
96     // init distances of paths with no intermediate nodes
97     for edge in graph.edge_references() {
98         dist[graph.to_index(edge.source())][graph.to_index(edge.target())] = edge_cost(edge);
99         if !graph.is_directed() {
100             dist[graph.to_index(edge.target())][graph.to_index(edge.source())] = edge_cost(edge);
101         }
102     }
103 
104     // distance of each node to itself is 0(default value)
105     for node in graph.node_identifiers() {
106         dist[graph.to_index(node)][graph.to_index(node)] = K::default();
107     }
108 
109     for k in 0..num_of_nodes {
110         for i in 0..num_of_nodes {
111             for j in 0..num_of_nodes {
112                 let (result, overflow) = dist[i][k].overflowing_add(dist[k][j]);
113                 if !overflow && dist[i][j] > result {
114                     dist[i][j] = result;
115                 }
116             }
117         }
118     }
119 
120     // value less than 0(default value) indicates a negative cycle
121     for i in 0..num_of_nodes {
122         if dist[i][i] < K::default() {
123             return Err(NegativeCycle(()));
124         }
125     }
126 
127     let mut distance_map: HashMap<(G::NodeId, G::NodeId), K> =
128         HashMap::with_capacity(num_of_nodes * num_of_nodes);
129 
130     for i in 0..num_of_nodes {
131         for j in 0..num_of_nodes {
132             distance_map.insert((graph.from_index(i), graph.from_index(j)), dist[i][j]);
133         }
134     }
135 
136     Ok(distance_map)
137 }
138