1 use std::fmt::{self, Debug};
2 
3 use super::chunks::ChunkProducer;
4 use super::plumbing::*;
5 use super::*;
6 use crate::math::div_round_up;
7 
8 /// `FoldChunksWith` is an iterator that groups elements of an underlying iterator and applies a
9 /// function over them, producing a single value for each group.
10 ///
11 /// This struct is created by the [`fold_chunks_with()`] method on [`IndexedParallelIterator`]
12 ///
13 /// [`fold_chunks_with()`]: trait.IndexedParallelIterator.html#method.fold_chunks
14 /// [`IndexedParallelIterator`]: trait.IndexedParallelIterator.html
15 #[must_use = "iterator adaptors are lazy and do nothing unless consumed"]
16 #[derive(Clone)]
17 pub struct FoldChunksWith<I, U, F>
18 where
19     I: IndexedParallelIterator,
20 {
21     base: I,
22     chunk_size: usize,
23     item: U,
24     fold_op: F,
25 }
26 
27 impl<I: IndexedParallelIterator + Debug, U: Debug, F> Debug for FoldChunksWith<I, U, F> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result28     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
29         f.debug_struct("Fold")
30             .field("base", &self.base)
31             .field("chunk_size", &self.chunk_size)
32             .field("item", &self.item)
33             .finish()
34     }
35 }
36 
37 impl<I, U, F> FoldChunksWith<I, U, F>
38 where
39     I: IndexedParallelIterator,
40     U: Send + Clone,
41     F: Fn(U, I::Item) -> U + Send + Sync,
42 {
43     /// Creates a new `FoldChunksWith` iterator
new(base: I, chunk_size: usize, item: U, fold_op: F) -> Self44     pub(super) fn new(base: I, chunk_size: usize, item: U, fold_op: F) -> Self {
45         FoldChunksWith {
46             base,
47             chunk_size,
48             item,
49             fold_op,
50         }
51     }
52 }
53 
54 impl<I, U, F> ParallelIterator for FoldChunksWith<I, U, F>
55 where
56     I: IndexedParallelIterator,
57     U: Send + Clone,
58     F: Fn(U, I::Item) -> U + Send + Sync,
59 {
60     type Item = U;
61 
drive_unindexed<C>(self, consumer: C) -> C::Result where C: Consumer<U>,62     fn drive_unindexed<C>(self, consumer: C) -> C::Result
63     where
64         C: Consumer<U>,
65     {
66         bridge(self, consumer)
67     }
68 
opt_len(&self) -> Option<usize>69     fn opt_len(&self) -> Option<usize> {
70         Some(self.len())
71     }
72 }
73 
74 impl<I, U, F> IndexedParallelIterator for FoldChunksWith<I, U, F>
75 where
76     I: IndexedParallelIterator,
77     U: Send + Clone,
78     F: Fn(U, I::Item) -> U + Send + Sync,
79 {
len(&self) -> usize80     fn len(&self) -> usize {
81         div_round_up(self.base.len(), self.chunk_size)
82     }
83 
drive<C>(self, consumer: C) -> C::Result where C: Consumer<Self::Item>,84     fn drive<C>(self, consumer: C) -> C::Result
85     where
86         C: Consumer<Self::Item>,
87     {
88         bridge(self, consumer)
89     }
90 
with_producer<CB>(self, callback: CB) -> CB::Output where CB: ProducerCallback<Self::Item>,91     fn with_producer<CB>(self, callback: CB) -> CB::Output
92     where
93         CB: ProducerCallback<Self::Item>,
94     {
95         let len = self.base.len();
96         return self.base.with_producer(Callback {
97             chunk_size: self.chunk_size,
98             len,
99             item: self.item,
100             fold_op: self.fold_op,
101             callback,
102         });
103 
104         struct Callback<CB, T, F> {
105             chunk_size: usize,
106             len: usize,
107             item: T,
108             fold_op: F,
109             callback: CB,
110         }
111 
112         impl<T, U, F, CB> ProducerCallback<T> for Callback<CB, U, F>
113         where
114             CB: ProducerCallback<U>,
115             U: Send + Clone,
116             F: Fn(U, T) -> U + Send + Sync,
117         {
118             type Output = CB::Output;
119 
120             fn callback<P>(self, base: P) -> CB::Output
121             where
122                 P: Producer<Item = T>,
123             {
124                 let item = self.item;
125                 let fold_op = &self.fold_op;
126                 let fold_iter = move |iter: P::IntoIter| iter.fold(item.clone(), fold_op);
127                 let producer = ChunkProducer::new(self.chunk_size, self.len, base, fold_iter);
128                 self.callback.callback(producer)
129             }
130         }
131     }
132 }
133 
#[cfg(test)]
mod test {
    use super::*;
    use std::ops::Add;

    // Plain addition, passed as the `fold_op` in most tests below.
    fn sum<T, U>(x: T, y: U) -> T
    where
        T: Add<U, Output = T>,
    {
        x + y
    }

    #[test]
    fn check_fold_chunks_with() {
        let letters: Vec<char> = "bishbashbosh!".chars().collect();
        let words: Vec<String> = letters
            .into_par_iter()
            .fold_chunks_with(4, String::new(), |mut word, letter| {
                word.push(letter);
                word
            })
            .collect();

        assert_eq!(words, vec!["bish", "bash", "bosh", "!"]);
    }

    #[test]
    #[should_panic(expected = "chunk_size must not be zero")]
    fn check_fold_chunks_zero_size() {
        let input = vec![1, 2, 3];
        let _: Vec<i32> = input.into_par_iter().fold_chunks_with(0, 0, sum).collect();
    }

    #[test]
    fn check_fold_chunks_even_size() {
        let sums: Vec<i32> = (1..10).into_par_iter().fold_chunks_with(3, 0, sum).collect();
        assert_eq!(vec![1 + 2 + 3, 4 + 5 + 6, 7 + 8 + 9], sums);
    }

    #[test]
    fn check_fold_chunks_with_empty() {
        let input: Vec<i32> = Vec::new();
        let sums: Vec<i32> = input.into_par_iter().fold_chunks_with(2, 0, sum).collect();
        assert_eq!(Vec::<i32>::new(), sums);
    }

    #[test]
    fn check_fold_chunks_len() {
        // Leftover items form one extra, shorter chunk; an empty input yields none.
        assert_eq!(4, (0..8).into_par_iter().fold_chunks_with(2, 0, sum).len());
        assert_eq!(3, (0..9).into_par_iter().fold_chunks_with(3, 0, sum).len());
        assert_eq!(3, (0..8).into_par_iter().fold_chunks_with(3, 0, sum).len());
        assert_eq!(1, (&[1]).par_iter().fold_chunks_with(3, 0, sum).len());
        assert_eq!(0, (0..0).into_par_iter().fold_chunks_with(3, 0, sum).len());
    }

    #[test]
    fn check_fold_chunks_uneven() {
        let cases: Vec<(Vec<u32>, usize, Vec<u32>)> = vec![
            ((0..5).collect(), 3, vec![0 + 1 + 2, 3 + 4]),
            (vec![1], 5, vec![1]),
            ((0..4).collect(), 3, vec![0 + 1 + 2, 3]),
        ];

        for (case, (input, size, expected)) in cases.into_iter().enumerate() {
            // Borrowing iterator (items are `&u32`), consumed front-to-back.
            let mut actual: Vec<u32> = Vec::new();
            input
                .par_iter()
                .fold_chunks_with(size, 0, sum)
                .collect_into_vec(&mut actual);
            assert_eq!(expected, actual, "Case {} failed", case);

            // Owning iterator, consumed back-to-front.
            actual.clear();
            input
                .into_par_iter()
                .fold_chunks_with(size, 0, sum)
                .rev()
                .collect_into_vec(&mut actual);
            let expected_rev: Vec<u32> = expected.into_iter().rev().collect();
            assert_eq!(expected_rev, actual, "Case {} reversed failed", case);
        }
    }
}
232