1 use std::collections::hash_map::Entry::{Occupied, Vacant};
2 use std::collections::{BinaryHeap, HashMap};
3 
4 use std::hash::Hash;
5 
6 use crate::scored::MinScored;
7 use crate::visit::{EdgeRef, GraphBase, IntoEdges, Visitable};
8 
9 use crate::algo::Measure;
10 
11 /// \[Generic\] A* shortest path algorithm.
12 ///
13 /// Computes the shortest path from `start` to `finish`, including the total path cost.
14 ///
15 /// `finish` is implicitly given via the `is_goal` callback, which should return `true` if the
16 /// given node is the finish node.
17 ///
18 /// The function `edge_cost` should return the cost for a particular edge. Edge costs must be
19 /// non-negative.
20 ///
21 /// The function `estimate_cost` should return the estimated cost to the finish for a particular
22 /// node. For the algorithm to find the actual shortest path, it should be admissible, meaning that
23 /// it should never overestimate the actual cost to get to the nearest goal node. Estimate costs
24 /// must also be non-negative.
25 ///
26 /// The graph should be `Visitable` and implement `IntoEdges`.
27 ///
28 /// # Example
29 /// ```
30 /// use petgraph::Graph;
31 /// use petgraph::algo::astar;
32 ///
33 /// let mut g = Graph::new();
34 /// let a = g.add_node((0., 0.));
35 /// let b = g.add_node((2., 0.));
36 /// let c = g.add_node((1., 1.));
37 /// let d = g.add_node((0., 2.));
38 /// let e = g.add_node((3., 3.));
39 /// let f = g.add_node((4., 2.));
40 /// g.extend_with_edges(&[
41 ///     (a, b, 2),
42 ///     (a, d, 4),
43 ///     (b, c, 1),
44 ///     (b, f, 7),
45 ///     (c, e, 5),
46 ///     (e, f, 1),
47 ///     (d, e, 1),
48 /// ]);
49 ///
50 /// // Graph represented with the weight of each edge
51 /// // Edges with '*' are part of the optimal path.
52 /// //
53 /// //     2       1
54 /// // a ----- b ----- c
55 /// // | 4*    | 7     |
56 /// // d       f       | 5
57 /// // | 1*    | 1*    |
58 /// // \------ e ------/
59 ///
60 /// let path = astar(&g, a, |finish| finish == f, |e| *e.weight(), |_| 0);
61 /// assert_eq!(path, Some((6, vec![a, d, e, f])));
62 /// ```
63 ///
64 /// Returns the total cost + the path of subsequent `NodeId` from start to finish, if one was
65 /// found.
astar<G, F, H, K, IsGoal>( graph: G, start: G::NodeId, mut is_goal: IsGoal, mut edge_cost: F, mut estimate_cost: H, ) -> Option<(K, Vec<G::NodeId>)> where G: IntoEdges + Visitable, IsGoal: FnMut(G::NodeId) -> bool, G::NodeId: Eq + Hash, F: FnMut(G::EdgeRef) -> K, H: FnMut(G::NodeId) -> K, K: Measure + Copy,66 pub fn astar<G, F, H, K, IsGoal>(
67     graph: G,
68     start: G::NodeId,
69     mut is_goal: IsGoal,
70     mut edge_cost: F,
71     mut estimate_cost: H,
72 ) -> Option<(K, Vec<G::NodeId>)>
73 where
74     G: IntoEdges + Visitable,
75     IsGoal: FnMut(G::NodeId) -> bool,
76     G::NodeId: Eq + Hash,
77     F: FnMut(G::EdgeRef) -> K,
78     H: FnMut(G::NodeId) -> K,
79     K: Measure + Copy,
80 {
81     let mut visit_next = BinaryHeap::new();
82     let mut scores = HashMap::new(); // g-values, cost to reach the node
83     let mut estimate_scores = HashMap::new(); // f-values, cost to reach + estimate cost to goal
84     let mut path_tracker = PathTracker::<G>::new();
85 
86     let zero_score = K::default();
87     scores.insert(start, zero_score);
88     visit_next.push(MinScored(estimate_cost(start), start));
89 
90     while let Some(MinScored(estimate_score, node)) = visit_next.pop() {
91         if is_goal(node) {
92             let path = path_tracker.reconstruct_path_to(node);
93             let cost = scores[&node];
94             return Some((cost, path));
95         }
96 
97         // This lookup can be unwrapped without fear of panic since the node was necessarily scored
98         // before adding it to `visit_next`.
99         let node_score = scores[&node];
100 
101         match estimate_scores.entry(node) {
102             Occupied(mut entry) => {
103                 // If the node has already been visited with an equal or lower score than now, then
104                 // we do not need to re-visit it.
105                 if *entry.get() <= estimate_score {
106                     continue;
107                 }
108                 entry.insert(estimate_score);
109             }
110             Vacant(entry) => {
111                 entry.insert(estimate_score);
112             }
113         }
114 
115         for edge in graph.edges(node) {
116             let next = edge.target();
117             let next_score = node_score + edge_cost(edge);
118 
119             match scores.entry(next) {
120                 Occupied(mut entry) => {
121                     // No need to add neighbors that we have already reached through a shorter path
122                     // than now.
123                     if *entry.get() <= next_score {
124                         continue;
125                     }
126                     entry.insert(next_score);
127                 }
128                 Vacant(entry) => {
129                     entry.insert(next_score);
130                 }
131             }
132 
133             path_tracker.set_predecessor(next, node);
134             let next_estimate_score = next_score + estimate_cost(next);
135             visit_next.push(MinScored(next_estimate_score, next));
136         }
137     }
138 
139     None
140 }
141 
142 struct PathTracker<G>
143 where
144     G: GraphBase,
145     G::NodeId: Eq + Hash,
146 {
147     came_from: HashMap<G::NodeId, G::NodeId>,
148 }
149 
150 impl<G> PathTracker<G>
151 where
152     G: GraphBase,
153     G::NodeId: Eq + Hash,
154 {
new() -> PathTracker<G>155     fn new() -> PathTracker<G> {
156         PathTracker {
157             came_from: HashMap::new(),
158         }
159     }
160 
set_predecessor(&mut self, node: G::NodeId, previous: G::NodeId)161     fn set_predecessor(&mut self, node: G::NodeId, previous: G::NodeId) {
162         self.came_from.insert(node, previous);
163     }
164 
reconstruct_path_to(&self, last: G::NodeId) -> Vec<G::NodeId>165     fn reconstruct_path_to(&self, last: G::NodeId) -> Vec<G::NodeId> {
166         let mut path = vec![last];
167 
168         let mut current = last;
169         while let Some(&previous) = self.came_from.get(&current) {
170             path.push(previous);
171             current = previous;
172         }
173 
174         path.reverse();
175 
176         path
177     }
178 }
179