1 use std::fmt::{self, Debug}; 2 3 use super::chunks::ChunkProducer; 4 use super::plumbing::*; 5 use super::*; 6 use crate::math::div_round_up; 7 8 /// `FoldChunksWith` is an iterator that groups elements of an underlying iterator and applies a 9 /// function over them, producing a single value for each group. 10 /// 11 /// This struct is created by the [`fold_chunks_with()`] method on [`IndexedParallelIterator`] 12 /// 13 /// [`fold_chunks_with()`]: trait.IndexedParallelIterator.html#method.fold_chunks 14 /// [`IndexedParallelIterator`]: trait.IndexedParallelIterator.html 15 #[must_use = "iterator adaptors are lazy and do nothing unless consumed"] 16 #[derive(Clone)] 17 pub struct FoldChunksWith<I, U, F> 18 where 19 I: IndexedParallelIterator, 20 { 21 base: I, 22 chunk_size: usize, 23 item: U, 24 fold_op: F, 25 } 26 27 impl<I: IndexedParallelIterator + Debug, U: Debug, F> Debug for FoldChunksWith<I, U, F> { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result28 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 29 f.debug_struct("Fold") 30 .field("base", &self.base) 31 .field("chunk_size", &self.chunk_size) 32 .field("item", &self.item) 33 .finish() 34 } 35 } 36 37 impl<I, U, F> FoldChunksWith<I, U, F> 38 where 39 I: IndexedParallelIterator, 40 U: Send + Clone, 41 F: Fn(U, I::Item) -> U + Send + Sync, 42 { 43 /// Creates a new `FoldChunksWith` iterator new(base: I, chunk_size: usize, item: U, fold_op: F) -> Self44 pub(super) fn new(base: I, chunk_size: usize, item: U, fold_op: F) -> Self { 45 FoldChunksWith { 46 base, 47 chunk_size, 48 item, 49 fold_op, 50 } 51 } 52 } 53 54 impl<I, U, F> ParallelIterator for FoldChunksWith<I, U, F> 55 where 56 I: IndexedParallelIterator, 57 U: Send + Clone, 58 F: Fn(U, I::Item) -> U + Send + Sync, 59 { 60 type Item = U; 61 drive_unindexed<C>(self, consumer: C) -> C::Result where C: Consumer<U>,62 fn drive_unindexed<C>(self, consumer: C) -> C::Result 63 where 64 C: Consumer<U>, 65 { 66 bridge(self, consumer) 67 } 68 opt_len(&self) -> Option<usize>69 fn opt_len(&self) -> Option<usize> { 70 Some(self.len()) 71 } 72 } 73 74 impl<I, U, F> IndexedParallelIterator for FoldChunksWith<I, U, F> 75 where 76 I: IndexedParallelIterator, 77 U: Send + Clone, 78 F: Fn(U, I::Item) -> U + Send + Sync, 79 { len(&self) -> usize80 fn len(&self) -> usize { 81 div_round_up(self.base.len(), self.chunk_size) 82 } 83 drive<C>(self, consumer: C) -> C::Result where C: Consumer<Self::Item>,84 fn drive<C>(self, consumer: C) -> C::Result 85 where 86 C: Consumer<Self::Item>, 87 { 88 bridge(self, consumer) 89 } 90 with_producer<CB>(self, callback: CB) -> CB::Output where CB: ProducerCallback<Self::Item>,91 fn with_producer<CB>(self, callback: CB) -> CB::Output 92 where 93 CB: ProducerCallback<Self::Item>, 94 { 95 let len = self.base.len(); 96 return self.base.with_producer(Callback { 97 chunk_size: self.chunk_size, 98 len, 99 item: self.item, 100 fold_op: self.fold_op, 101 callback, 102 }); 103 104 struct Callback<CB, T, F> { 105 chunk_size: usize, 106 len: usize, 107 item: T, 108 fold_op: F, 109 callback: CB, 110 } 111 112 impl<T, U, F, CB> ProducerCallback<T> for Callback<CB, U, F> 113 where 114 CB: ProducerCallback<U>, 115 U: Send + Clone, 116 F: Fn(U, T) -> U + Send + Sync, 117 { 118 type Output = CB::Output; 119 120 fn callback<P>(self, base: P) -> CB::Output 121 where 122 P: Producer<Item = T>, 123 { 124 let item = self.item; 125 let fold_op = &self.fold_op; 126 let fold_iter = move |iter: P::IntoIter| iter.fold(item.clone(), fold_op); 127 let producer = ChunkProducer::new(self.chunk_size, self.len, base, fold_iter); 128 self.callback.callback(producer) 129 } 130 } 131 } 132 } 133 134 #[cfg(test)] 135 mod test { 136 use super::*; 137 use std::ops::Add; 138 139 #[test] check_fold_chunks_with()140 fn check_fold_chunks_with() { 141 let words = "bishbashbosh!" 142 .chars() 143 .collect::<Vec<_>>() 144 .into_par_iter() 145 .fold_chunks_with(4, String::new(), |mut s, c| { 146 s.push(c); 147 s 148 }) 149 .collect::<Vec<_>>(); 150 151 assert_eq!(words, vec!["bish", "bash", "bosh", "!"]); 152 } 153 154 // 'closure' value for tests below sum<T, U>(x: T, y: U) -> T where T: Add<U, Output = T>,155 fn sum<T, U>(x: T, y: U) -> T 156 where 157 T: Add<U, Output = T>, 158 { 159 x + y 160 } 161 162 #[test] 163 #[should_panic(expected = "chunk_size must not be zero")] check_fold_chunks_zero_size()164 fn check_fold_chunks_zero_size() { 165 let _: Vec<i32> = vec![1, 2, 3] 166 .into_par_iter() 167 .fold_chunks_with(0, 0, sum) 168 .collect(); 169 } 170 171 #[test] check_fold_chunks_even_size()172 fn check_fold_chunks_even_size() { 173 assert_eq!( 174 vec![1 + 2 + 3, 4 + 5 + 6, 7 + 8 + 9], 175 (1..10) 176 .into_par_iter() 177 .fold_chunks_with(3, 0, sum) 178 .collect::<Vec<i32>>() 179 ); 180 } 181 182 #[test] check_fold_chunks_with_empty()183 fn check_fold_chunks_with_empty() { 184 let v: Vec<i32> = vec![]; 185 let expected: Vec<i32> = vec![]; 186 assert_eq!( 187 expected, 188 v.into_par_iter() 189 .fold_chunks_with(2, 0, sum) 190 .collect::<Vec<i32>>() 191 ); 192 } 193 194 #[test] check_fold_chunks_len()195 fn check_fold_chunks_len() { 196 assert_eq!(4, (0..8).into_par_iter().fold_chunks_with(2, 0, sum).len()); 197 assert_eq!(3, (0..9).into_par_iter().fold_chunks_with(3, 0, sum).len()); 198 assert_eq!(3, (0..8).into_par_iter().fold_chunks_with(3, 0, sum).len()); 199 assert_eq!(1, (&[1]).par_iter().fold_chunks_with(3, 0, sum).len()); 200 assert_eq!(0, (0..0).into_par_iter().fold_chunks_with(3, 0, sum).len()); 201 } 202 203 #[test] check_fold_chunks_uneven()204 fn check_fold_chunks_uneven() { 205 let cases: Vec<(Vec<u32>, usize, Vec<u32>)> = vec![ 206 ((0..5).collect(), 3, vec![0 + 1 + 2, 3 + 4]), 207 (vec![1], 5, vec![1]), 208 ((0..4).collect(), 3, vec![0 + 1 + 2, 3]), 209 ]; 210 211 for (i, (v, n, expected)) in cases.into_iter().enumerate() { 212 let mut res: Vec<u32> = vec![]; 213 v.par_iter() 214 .fold_chunks_with(n, 0, sum) 215 .collect_into_vec(&mut res); 216 assert_eq!(expected, res, "Case {} failed", i); 217 218 res.truncate(0); 219 v.into_par_iter() 220 .fold_chunks_with(n, 0, sum) 221 .rev() 222 .collect_into_vec(&mut res); 223 assert_eq!( 224 expected.into_iter().rev().collect::<Vec<u32>>(), 225 res, 226 "Case {} reversed failed", 227 i 228 ); 229 } 230 } 231 } 232