// Copyright © 2024 Mel Henning
// SPDX-License-Identifier: MIT

use std::collections::HashMap;
use std::hash::Hash;

/// Per-set metadata stored in the root node of each tree.
#[derive(Copy, Clone)]
struct Root<X: Copy> {
    /// Number of nodes in the tree rooted here; drives "link by size".
    size: usize,
    /// The value that `find()` returns for every member of this set.
    representative: X,
}

/// One slot of the union-find forest.
#[derive(Copy, Clone)]
enum Node<X: Copy> {
    /// A non-root node, pointing at another slot in `UnionFind::nodes`.
    Child { parent_idx: usize },
    /// The root of a tree, carrying the set's metadata.
    Root(Root<X>),
}

/// Union-find structure
///
/// This implementation follows Tarjan and van Leeuwen - specifically the
/// "link by size" and "halving" variant.
///
/// Only values that have been passed to [`UnionFind::union`] together with a
/// distinct partner are stored. Every other value is implicitly in its own
/// singleton set and is its own representative.
///
/// Robert E. Tarjan and Jan van Leeuwen. 1984. Worst-case Analysis of Set
/// Union Algorithms. J. ACM 31, 2 (April 1984), 245–281.
/// <https://doi.org/10.1145/62.2160>
pub struct UnionFind<X: Copy + Hash + Eq> {
    /// Slot index in `nodes` for each value that has been unioned.
    idx_map: HashMap<X, usize>,
    /// The forest itself. Slots are only ever appended, so indices are stable.
    nodes: Vec<Node<X>>,
}

impl<X: Copy + Hash + Eq> UnionFind<X> {
    /// Create a new union-find structure
    ///
    /// At initialization, each possible value is in its own set
    pub fn new() -> Self {
        UnionFind {
            idx_map: HashMap::new(),
            nodes: Vec::new(),
        }
    }

    /// Walk from slot `idx` up to the root of its tree.
    ///
    /// Returns the root's slot index together with a copy of its [`Root`]
    /// data.
    ///
    /// Along the way this applies path halving. Whenever the current node's
    /// parent is not a root, the current node is re-pointed at its
    /// grandparent and the walk jumps there. As a result, every other node on
    /// the path is shortened.
    fn find_root(&mut self, mut idx: usize) -> (usize, Root<X>) {
        loop {
            match self.nodes[idx] {
                Node::Child { parent_idx } => {
                    match self.nodes[parent_idx] {
                        Node::Child {
                            parent_idx: grandparent_idx,
                        } => {
                            // "Halving" in Tarjan and van Leeuwen
                            self.nodes[idx] = Node::Child {
                                parent_idx: grandparent_idx,
                            };
                            idx = grandparent_idx;
                        }
                        Node::Root(parent_root) => {
                            return (parent_idx, parent_root)
                        }
                    }
                }
                Node::Root(root) => return (idx, root),
            }
        }
    }

    /// Find the representative element for x
    ///
    /// A value that has never been unioned with anything is its own
    /// representative. This takes `&mut self` because the lookup halves
    /// paths as it goes.
    pub fn find(&mut self, x: X) -> X {
        match self.idx_map.get(&x) {
            Some(&idx) => {
                let (_, Root { representative, .. }) = self.find_root(idx);
                representative
            }
            None => x,
        }
    }

    /// Return the slot index for `x`.
    ///
    /// If `x` has not been seen before, a fresh singleton root is allocated
    /// for it first.
    fn map_or_create(&mut self, x: X) -> usize {
        // Borrow `nodes` separately so the closure does not need all of
        // `self` while `idx_map` is mutably borrowed. This works on every
        // edition, not only with 2021's disjoint closure captures.
        let nodes = &mut self.nodes;
        *self.idx_map.entry(x).or_insert_with(|| {
            nodes.push(Node::Root(Root {
                size: 1,
                representative: x,
            }));
            nodes.len() - 1
        })
    }

    /// Union the sets containing a and b
    ///
    /// If a and b are in different sets, the representative for a will
    /// become the representative of the combined set.
    ///
    /// If they are already in the same set (including `a == b`), the sets
    /// and their representative are left unchanged.
    pub fn union(&mut self, a: X, b: X) {
        if a == b {
            return;
        }

        let a_idx = self.map_or_create(a);
        let b_idx = self.map_or_create(b);
        let (a_root_idx, a_root) = self.find_root(a_idx);
        let (b_root_idx, b_root) = self.find_root(b_idx);

        if a_root_idx != b_root_idx {
            // Link by size: hang the smaller tree under the larger one to
            // keep the trees balanced. Which slot survives as the root does
            // not affect the representative: the merged set always keeps
            // a's representative.
            let (new_root_idx, new_child_idx) = if a_root.size >= b_root.size {
                (a_root_idx, b_root_idx)
            } else {
                (b_root_idx, a_root_idx)
            };

            self.nodes[new_root_idx] = Node::Root(Root {
                size: a_root.size + b_root.size,
                representative: a_root.representative,
            });
            self.nodes[new_child_idx] = Node::Child {
                parent_idx: new_root_idx,
            };
        }
    }

    /// Return true if find() is the identity mapping
    ///
    /// Nodes are only allocated by a union of two distinct values. An empty
    /// node list therefore means no such union has happened yet.
    pub fn is_empty(&self) -> bool {
        self.nodes.is_empty()
    }
}

impl<X: Copy + Hash + Eq> Default for UnionFind<X> {
    /// Equivalent to [`UnionFind::new`].
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::{Node, UnionFind};
    use std::cmp::max;
    use std::hash::Hash;

    /// ceil(log2(x)) for x > 0.
    fn ceil_log2(x: usize) -> u32 {
        assert!(x > 0);
        usize::BITS - (x - 1).leading_zeros()
    }

    struct HeightInfo {
        height: u32,
        size: usize,
    }

    /// Test helper that rebuilds the forest top-down and checks the size and
    /// height invariants of every tree.
    pub struct HeightCalc<'a, X: Copy + Hash + Eq> {
        uf: &'a UnionFind<X>,
        /// For each slot, the slots that point at it as their parent.
        downward_edges: Vec<Vec<usize>>,
    }

    impl<'a, X: Copy + Hash + Eq> HeightCalc<'a, X> {
        fn new(uf: &'a UnionFind<X>) -> Self {
            let mut downward_edges: Vec<Vec<usize>> =
                vec![Vec::new(); uf.nodes.len()];
            for (i, node) in uf.nodes.iter().enumerate() {
                if let Node::Child { parent_idx } = node {
                    downward_edges[*parent_idx].push(i);
                }
            }

            HeightCalc { uf, downward_edges }
        }

        /// Height and node count of the subtree rooted at slot `idx`.
        fn calc_info(&self, idx: usize) -> HeightInfo {
            let mut result = HeightInfo { height: 0, size: 1 };
            for child in &self.downward_edges[idx] {
                let child_result = self.calc_info(*child);
                result.height = max(result.height, child_result.height + 1);
                result.size += child_result.size;
            }
            result
        }

        /// Check every tree and return the maximum tree height.
        ///
        /// For each tree, the stored size must match the real node count, and
        /// the height must not exceed floor(log2(size)).
        fn check_roots(&self) -> u32 {
            let mut total_size = 0;
            let mut max_height = 0;
            for (i, node) in self.uf.nodes.iter().enumerate() {
                if let Node::Root(root) = node {
                    let info = self.calc_info(i);
                    assert_eq!(root.size, info.size);

                    total_size += info.size;
                    max_height = max(max_height, info.height);

                    let max_expected_height = ceil_log2(root.size + 1) - 1;
                    if info.height > max_expected_height {
                        eprintln!(
                            "height {}\t max_expected_height {}\t size {}",
                            info.height, max_expected_height, info.size
                        );
                    }
                    assert!(info.height <= max_expected_height);
                }
            }
            assert_eq!(total_size, self.uf.nodes.len());
            assert_eq!(total_size, self.uf.idx_map.len());
            max_height
        }

        pub fn check(uf: &'a UnionFind<X>) -> u32 {
            HeightCalc::new(uf).check_roots()
        }
    }

    #[test]
    fn test_basic() {
        let mut f = UnionFind::new();
        assert_eq!(f.find(10), 10);
        assert_eq!(f.find(12), 12);

        f.union(10, 12);
        f.union(11, 13);

        HeightCalc::check(&f);

        assert_eq!(f.find(13), 11);
        assert_eq!(f.find(12), 10);
        assert_eq!(f.find(11), 11);
        assert_eq!(f.find(10), 10);

        f.union(12, 13);

        HeightCalc::check(&f);

        assert_eq!(f.find(13), 10);
        assert_eq!(f.find(12), 10);
        assert_eq!(f.find(11), 10);
        assert_eq!(f.find(10), 10);

        assert_eq!(f.find(14), 14);

        HeightCalc::check(&f);

        // Union the set with itself
        f.union(11, 10);

        HeightCalc::check(&f);

        assert_eq!(f.find(13), 10);
        assert_eq!(f.find(12), 10);
        assert_eq!(f.find(11), 10);
        assert_eq!(f.find(10), 10);
    }

    #[test]
    fn test_chain_a_height() {
        let mut f = UnionFind::new();
        for i in 0..1000 {
            f.union(i, i + 1);
            HeightCalc::check(&f);
        }
        assert_eq!(f.find(1000), 0);
    }

    #[test]
    fn test_chain_b_height() {
        let mut f = UnionFind::new();
        for i in 0..1000 {
            f.union(i + 1, i);
            HeightCalc::check(&f);
        }
        assert_eq!(f.find(0), 1000);
    }

    #[test]
    fn test_binary_tree_height() {
        let height = 8;
        let count = 1 << height;

        let mut f = UnionFind::new();
        for current_height in 0..height {
            let stride = 1 << current_height;
            for i in (0..count).step_by(2 * stride) {
                f.union(i, i + stride);
            }
            let actual_height = HeightCalc::check(&f);

            // actual_height can vary based on tiebreaker condition
            assert!(
                actual_height == current_height
                    || actual_height == current_height + 1
            );
        }

        // Check path halving
        let actual_height_before = HeightCalc::check(&f);
        for i in 0..count {
            assert_eq!(f.find(i), 0);
        }
        let actual_height_after = HeightCalc::check(&f);

        assert!(actual_height_after <= actual_height_before.div_ceil(2));
    }
}