// xref: /aosp_15_r20/external/mesa3d/src/nouveau/compiler/nak/union_find.rs (revision 6104692788411f58d303aa86923a9ff6ecaded22)
// Copyright © 2024 Mel Henning
// SPDX-License-Identifier: MIT

use std::collections::HashMap;
use std::hash::Hash;

/// Bookkeeping stored at the root of each tree in the forest.
#[derive(Copy, Clone)]
struct Root<X: Copy> {
    /// Number of nodes in the tree hanging off this root, root included.
    size: usize,
    /// Value that `find` reports for every member of the set.
    ///
    /// This is not necessarily the value whose node is the root: `union`
    /// picks the root by tree size but always keeps the representative of
    /// its first argument.
    representative: X,
}

/// One entry of `UnionFind::nodes`.
#[derive(Copy, Clone)]
enum Node<X: Copy> {
    /// A non-root node; `parent_idx` indexes its parent in `nodes`.
    Child { parent_idx: usize },
    /// The root of a tree, carrying the data for the whole set.
    Root(Root<X>),
}

/// Union-find structure
///
/// This implementation follows Tarjan and van Leeuwen - specifically the
/// "link by size" and "halving" variant.
///
/// Robert E. Tarjan and Jan van Leeuwen. 1984. Worst-case Analysis of Set
///     Union Algorithms. J. ACM 31, 2 (April 1984), 245–281.
///     https://doi.org/10.1145/62.2160
pub struct UnionFind<X: Copy + Hash + Eq> {
    /// Index into `nodes` for every value that has been passed to a
    /// `union` of two distinct values. Values without an entry are
    /// implicitly singleton sets.
    idx_map: HashMap<X, usize>,
    /// Forest of parent-pointer trees, one tree per set that contains
    /// more than one value.
    nodes: Vec<Node<X>>,
}

impl<X: Copy + Hash + Eq> UnionFind<X> {
    /// Create a new union-find structure
    ///
    /// At initialization, each possible value is in its own set
    pub fn new() -> Self {
        UnionFind {
            idx_map: HashMap::new(),
            nodes: Vec::new(),
        }
    }

    /// Walk from `idx` up to the root of its tree.
    ///
    /// Returns the index of the root in `nodes` together with a copy of
    /// the root's data. Nodes visited along the way are re-pointed at
    /// their grandparent so later lookups follow a shorter path.
    ///
    /// # Panics
    ///
    /// Panics if `idx` is not a valid index into `nodes`.
    fn find_root(&mut self, mut idx: usize) -> (usize, Root<X>) {
        loop {
            match self.nodes[idx] {
                Node::Child { parent_idx } => {
                    match self.nodes[parent_idx] {
                        Node::Child {
                            parent_idx: grandparent_idx,
                        } => {
                            // "Halving" in Tarjan and van Leeuwen
                            self.nodes[idx] = Node::Child {
                                parent_idx: grandparent_idx,
                            };
                            idx = grandparent_idx;
                        }
                        Node::Root(parent_root) => {
                            return (parent_idx, parent_root);
                        }
                    }
                }
                Node::Root(root) => return (idx, root),
            }
        }
    }

    /// Find the representative element for x
    ///
    /// A value that has no node yet is its own representative. This takes
    /// `&mut self` because the lookup shortens the paths it walks.
    pub fn find(&mut self, x: X) -> X {
        match self.idx_map.get(&x).copied() {
            Some(idx) => self.find_root(idx).1.representative,
            None => x,
        }
    }

    /// Return the node index for `x`, creating a singleton root for it if
    /// it does not have one yet.
    fn map_or_create(&mut self, x: X) -> usize {
        // Borrow the node list separately from the map so the closure
        // below does not need to capture `self` as a whole.
        let nodes = &mut self.nodes;
        *self.idx_map.entry(x).or_insert_with(|| {
            nodes.push(Node::Root(Root {
                size: 1,
                representative: x,
            }));
            nodes.len() - 1
        })
    }

    /// Union the sets containing a and b
    ///
    /// The representative for a will become the representative of
    /// the combined set
    pub fn union(&mut self, a: X, b: X) {
        // Nothing to merge, and no reason to allocate a node for `a`.
        if a == b {
            return;
        }

        let a_idx = self.map_or_create(a);
        let b_idx = self.map_or_create(b);
        let (a_root_idx, a_root) = self.find_root(a_idx);
        let (b_root_idx, b_root) = self.find_root(b_idx);

        // Already in the same set.
        if a_root_idx == b_root_idx {
            return;
        }

        // Keep the tree balanced: the root of the larger tree stays the
        // root, with ties going to a's root.
        let (new_root_idx, new_child_idx) = if a_root.size >= b_root.size {
            (a_root_idx, b_root_idx)
        } else {
            (b_root_idx, a_root_idx)
        };

        // Whichever node ends up as the root, it reports a's
        // representative for the merged set.
        self.nodes[new_root_idx] = Node::Root(Root {
            size: a_root.size + b_root.size,
            representative: a_root.representative,
        });
        self.nodes[new_child_idx] = Node::Child {
            parent_idx: new_root_idx,
        };
    }

    /// Return true if find() is the identity mapping
    ///
    /// This holds exactly when no `union` of two distinct values has been
    /// performed, since that is the only operation that creates nodes.
    pub fn is_empty(&self) -> bool {
        self.nodes.is_empty()
    }
}

impl<X: Copy + Hash + Eq> Default for UnionFind<X> {
    /// Equivalent to [`UnionFind::new`].
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::{Node, UnionFind};
    use std::cmp::max;
    use std::hash::Hash;

    /// Smallest `k` such that `2^k >= x`. `x` must be non-zero.
    fn ceil_log2(x: usize) -> u32 {
        assert!(x > 0);
        usize::BITS - (x - 1).leading_zeros()
    }

    /// Measurements of one subtree, as recomputed by `HeightCalc`.
    struct HeightInfo {
        /// Number of edges on the longest downward path from the subtree's
        /// top node.
        height: u32,
        /// Number of nodes in the subtree, top node included.
        size: usize,
    }

    /// Test helper that re-derives tree shapes from a `UnionFind` and
    /// checks them against the structure's own bookkeeping.
    pub struct HeightCalc<'a, X: Copy + Hash + Eq> {
        uf: &'a UnionFind<X>,
        /// For each node index, the indices of the nodes whose parent it is.
        downward_edges: Vec<Vec<usize>>,
    }

    impl<'a, X: Copy + Hash + Eq> HeightCalc<'a, X> {
        /// Build the child lists by inverting every parent pointer in `uf`.
        fn new(uf: &'a UnionFind<X>) -> Self {
            let mut downward_edges: Vec<Vec<usize>> =
                vec![Vec::new(); uf.nodes.len()];
            for (i, node) in uf.nodes.iter().enumerate() {
                if let Node::Child { parent_idx } = node {
                    downward_edges[*parent_idx].push(i);
                }
            }

            HeightCalc { uf, downward_edges }
        }

        /// Recursively compute the height and size of the subtree at `idx`.
        fn calc_info(&self, idx: usize) -> HeightInfo {
            let mut result = HeightInfo { height: 0, size: 1 };
            for &child in &self.downward_edges[idx] {
                let child_result = self.calc_info(child);
                result.height = max(result.height, child_result.height + 1);
                result.size += child_result.size;
            }
            result
        }

        /// Verify every tree and return the height of the tallest one.
        ///
        /// Checks that each root's recorded size matches the recomputed
        /// size, that no tree is taller than the bound below, and that
        /// every node is reachable from exactly one root and has exactly
        /// one `idx_map` entry.
        fn check_roots(&self) -> u32 {
            let mut total_size = 0;
            let mut max_height = 0;
            for (i, node) in self.uf.nodes.iter().enumerate() {
                if let Node::Root(root) = node {
                    let info = self.calc_info(i);
                    assert_eq!(root.size, info.size);

                    total_size += info.size;
                    max_height = max(max_height, info.height);

                    // Equals floor(log2(size)).
                    let max_expected_height = ceil_log2(root.size + 1) - 1;
                    assert!(
                        info.height <= max_expected_height,
                        "height {}\t max_expected_height {}\t size {}",
                        info.height,
                        max_expected_height,
                        info.size
                    );
                }
            }
            assert_eq!(total_size, self.uf.nodes.len());
            assert_eq!(total_size, self.uf.idx_map.len());
            max_height
        }

        /// Run all structural checks on `uf` and return the maximum tree
        /// height.
        pub fn check(uf: &'a UnionFind<X>) -> u32 {
            HeightCalc::new(uf).check_roots()
        }
    }

    #[test]
    fn test_basic() {
        let mut f = UnionFind::new();
        assert_eq!(f.find(10), 10);
        assert_eq!(f.find(12), 12);

        f.union(10, 12);
        f.union(11, 13);

        HeightCalc::check(&f);

        assert_eq!(f.find(13), 11);
        assert_eq!(f.find(12), 10);
        assert_eq!(f.find(11), 11);
        assert_eq!(f.find(10), 10);

        f.union(12, 13);

        HeightCalc::check(&f);

        assert_eq!(f.find(13), 10);
        assert_eq!(f.find(12), 10);
        assert_eq!(f.find(11), 10);
        assert_eq!(f.find(10), 10);

        assert_eq!(f.find(14), 14);

        HeightCalc::check(&f);

        // Union the set with itself
        f.union(11, 10);

        HeightCalc::check(&f);

        assert_eq!(f.find(13), 10);
        assert_eq!(f.find(12), 10);
        assert_eq!(f.find(11), 10);
        assert_eq!(f.find(10), 10);
    }

    #[test]
    fn test_chain_a_height() {
        // Always merge a singleton into the growing set.
        let mut f = UnionFind::new();
        for i in 0..1000 {
            f.union(i, i + 1);
            HeightCalc::check(&f);
        }
        assert_eq!(f.find(1000), 0);
    }

    #[test]
    fn test_chain_b_height() {
        // Always merge the growing set into a singleton, so the
        // representative changes on every step.
        let mut f = UnionFind::new();
        for i in 0..1000 {
            f.union(i + 1, i);
            HeightCalc::check(&f);
        }
        assert_eq!(f.find(0), 1000);
    }

    #[test]
    fn test_binary_tree_height() {
        let height = 8;
        let count = 1 << height;

        // Merge equal-sized sets pairwise, doubling the set size each round.
        let mut f = UnionFind::new();
        for current_height in 0..height {
            let stride = 1 << current_height;
            for i in (0..count).step_by(2 * stride) {
                f.union(i, i + stride);
            }
            let actual_height = HeightCalc::check(&f);

            // actual_height can vary based on tiebreaker condition
            assert!(
                actual_height == current_height
                    || actual_height == current_height + 1
            );
        }

        // Check path halving
        let actual_height_before = HeightCalc::check(&f);
        for i in 0..count {
            assert_eq!(f.find(i), 0);
        }
        let actual_height_after = HeightCalc::check(&f);

        assert!(actual_height_after <= actual_height_before.div_ceil(2));
    }
}