1 use std::fmt::{self, Debug}; 2 3 use super::chunks::ChunkProducer; 4 use super::plumbing::*; 5 use super::*; 6 use crate::math::div_round_up; 7 8 /// `FoldChunks` is an iterator that groups elements of an underlying iterator and applies a 9 /// function over them, producing a single value for each group. 10 /// 11 /// This struct is created by the [`fold_chunks()`] method on [`IndexedParallelIterator`] 12 /// 13 /// [`fold_chunks()`]: trait.IndexedParallelIterator.html#method.fold_chunks 14 /// [`IndexedParallelIterator`]: trait.IndexedParallelIterator.html 15 #[must_use = "iterator adaptors are lazy and do nothing unless consumed"] 16 #[derive(Clone)] 17 pub struct FoldChunks<I, ID, F> 18 where 19 I: IndexedParallelIterator, 20 { 21 base: I, 22 chunk_size: usize, 23 fold_op: F, 24 identity: ID, 25 } 26 27 impl<I: IndexedParallelIterator + Debug, ID, F> Debug for FoldChunks<I, ID, F> { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result28 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 29 f.debug_struct("Fold") 30 .field("base", &self.base) 31 .field("chunk_size", &self.chunk_size) 32 .finish() 33 } 34 } 35 36 impl<I, ID, U, F> FoldChunks<I, ID, F> 37 where 38 I: IndexedParallelIterator, 39 ID: Fn() -> U + Send + Sync, 40 F: Fn(U, I::Item) -> U + Send + Sync, 41 U: Send, 42 { 43 /// Creates a new `FoldChunks` iterator new(base: I, chunk_size: usize, identity: ID, fold_op: F) -> Self44 pub(super) fn new(base: I, chunk_size: usize, identity: ID, fold_op: F) -> Self { 45 FoldChunks { 46 base, 47 chunk_size, 48 identity, 49 fold_op, 50 } 51 } 52 } 53 54 impl<I, ID, U, F> ParallelIterator for FoldChunks<I, ID, F> 55 where 56 I: IndexedParallelIterator, 57 ID: Fn() -> U + Send + Sync, 58 F: Fn(U, I::Item) -> U + Send + Sync, 59 U: Send, 60 { 61 type Item = U; 62 drive_unindexed<C>(self, consumer: C) -> C::Result where C: Consumer<U>,63 fn drive_unindexed<C>(self, consumer: C) -> C::Result 64 where 65 C: Consumer<U>, 66 { 67 bridge(self, consumer) 68 } 69 opt_len(&self) -> Option<usize>70 fn opt_len(&self) -> Option<usize> { 71 Some(self.len()) 72 } 73 } 74 75 impl<I, ID, U, F> IndexedParallelIterator for FoldChunks<I, ID, F> 76 where 77 I: IndexedParallelIterator, 78 ID: Fn() -> U + Send + Sync, 79 F: Fn(U, I::Item) -> U + Send + Sync, 80 U: Send, 81 { len(&self) -> usize82 fn len(&self) -> usize { 83 div_round_up(self.base.len(), self.chunk_size) 84 } 85 drive<C>(self, consumer: C) -> C::Result where C: Consumer<Self::Item>,86 fn drive<C>(self, consumer: C) -> C::Result 87 where 88 C: Consumer<Self::Item>, 89 { 90 bridge(self, consumer) 91 } 92 with_producer<CB>(self, callback: CB) -> CB::Output where CB: ProducerCallback<Self::Item>,93 fn with_producer<CB>(self, callback: CB) -> CB::Output 94 where 95 CB: ProducerCallback<Self::Item>, 96 { 97 let len = self.base.len(); 98 return self.base.with_producer(Callback { 99 chunk_size: self.chunk_size, 100 len, 101 identity: self.identity, 102 fold_op: self.fold_op, 103 callback, 104 }); 105 106 struct Callback<CB, ID, F> { 107 chunk_size: usize, 108 len: usize, 109 identity: ID, 110 fold_op: F, 111 callback: CB, 112 } 113 114 impl<T, CB, ID, U, F> ProducerCallback<T> for Callback<CB, ID, F> 115 where 116 CB: ProducerCallback<U>, 117 ID: Fn() -> U + Send + Sync, 118 F: Fn(U, T) -> U + Send + Sync, 119 { 120 type Output = CB::Output; 121 122 fn callback<P>(self, base: P) -> CB::Output 123 where 124 P: Producer<Item = T>, 125 { 126 let identity = &self.identity; 127 let fold_op = &self.fold_op; 128 let fold_iter = move |iter: P::IntoIter| iter.fold(identity(), fold_op); 129 let producer = ChunkProducer::new(self.chunk_size, self.len, base, fold_iter); 130 self.callback.callback(producer) 131 } 132 } 133 } 134 } 135 136 #[cfg(test)] 137 mod test { 138 use super::*; 139 use std::ops::Add; 140 141 #[test] check_fold_chunks()142 fn check_fold_chunks() { 143 let words = "bishbashbosh!" 144 .chars() 145 .collect::<Vec<_>>() 146 .into_par_iter() 147 .fold_chunks(4, String::new, |mut s, c| { 148 s.push(c); 149 s 150 }) 151 .collect::<Vec<_>>(); 152 153 assert_eq!(words, vec!["bish", "bash", "bosh", "!"]); 154 } 155 156 // 'closure' values for tests below id() -> i32157 fn id() -> i32 { 158 0 159 } sum<T, U>(x: T, y: U) -> T where T: Add<U, Output = T>,160 fn sum<T, U>(x: T, y: U) -> T 161 where 162 T: Add<U, Output = T>, 163 { 164 x + y 165 } 166 167 #[test] 168 #[should_panic(expected = "chunk_size must not be zero")] check_fold_chunks_zero_size()169 fn check_fold_chunks_zero_size() { 170 let _: Vec<i32> = vec![1, 2, 3] 171 .into_par_iter() 172 .fold_chunks(0, id, sum) 173 .collect(); 174 } 175 176 #[test] check_fold_chunks_even_size()177 fn check_fold_chunks_even_size() { 178 assert_eq!( 179 vec![1 + 2 + 3, 4 + 5 + 6, 7 + 8 + 9], 180 (1..10) 181 .into_par_iter() 182 .fold_chunks(3, id, sum) 183 .collect::<Vec<i32>>() 184 ); 185 } 186 187 #[test] check_fold_chunks_empty()188 fn check_fold_chunks_empty() { 189 let v: Vec<i32> = vec![]; 190 let expected: Vec<i32> = vec![]; 191 assert_eq!( 192 expected, 193 v.into_par_iter() 194 .fold_chunks(2, id, sum) 195 .collect::<Vec<i32>>() 196 ); 197 } 198 199 #[test] check_fold_chunks_len()200 fn check_fold_chunks_len() { 201 assert_eq!(4, (0..8).into_par_iter().fold_chunks(2, id, sum).len()); 202 assert_eq!(3, (0..9).into_par_iter().fold_chunks(3, id, sum).len()); 203 assert_eq!(3, (0..8).into_par_iter().fold_chunks(3, id, sum).len()); 204 assert_eq!(1, (&[1]).par_iter().fold_chunks(3, id, sum).len()); 205 assert_eq!(0, (0..0).into_par_iter().fold_chunks(3, id, sum).len()); 206 } 207 208 #[test] check_fold_chunks_uneven()209 fn check_fold_chunks_uneven() { 210 let cases: Vec<(Vec<u32>, usize, Vec<u32>)> = vec![ 211 ((0..5).collect(), 3, vec![0 + 1 + 2, 3 + 4]), 212 (vec![1], 5, vec![1]), 213 ((0..4).collect(), 3, vec![0 + 1 + 2, 3]), 214 ]; 215 216 for (i, (v, n, expected)) in cases.into_iter().enumerate() { 217 let mut res: Vec<u32> = vec![]; 218 v.par_iter() 219 .fold_chunks(n, || 0, sum) 220 .collect_into_vec(&mut res); 221 assert_eq!(expected, res, "Case {} failed", i); 222 223 res.truncate(0); 224 v.into_par_iter() 225 .fold_chunks(n, || 0, sum) 226 .rev() 227 .collect_into_vec(&mut res); 228 assert_eq!( 229 expected.into_iter().rev().collect::<Vec<u32>>(), 230 res, 231 "Case {} reversed failed", 232 i 233 ); 234 } 235 } 236 } 237