1 use std::fmt::{self, Debug};
2 
3 use super::chunks::ChunkProducer;
4 use super::plumbing::*;
5 use super::*;
6 use crate::math::div_round_up;
7 
/// `FoldChunks` is an iterator that groups elements of an underlying iterator and applies a
/// function over them, producing a single value for each group.
///
/// Each group holds `chunk_size` consecutive elements of the base iterator, except possibly the
/// last group, which holds whatever elements remain. The adaptor therefore yields
/// `base.len() / chunk_size` values, rounded up.
///
/// This struct is created by the [`fold_chunks()`] method on [`IndexedParallelIterator`].
///
/// [`fold_chunks()`]: trait.IndexedParallelIterator.html#method.fold_chunks
/// [`IndexedParallelIterator`]: trait.IndexedParallelIterator.html
#[must_use = "iterator adaptors are lazy and do nothing unless consumed"]
#[derive(Clone)]
pub struct FoldChunks<I, ID, F>
where
    I: IndexedParallelIterator,
{
    // The indexed iterator whose elements are grouped into chunks.
    base: I,
    // Number of base elements per chunk (the final chunk may be shorter).
    chunk_size: usize,
    // Folds one base element into a chunk's running accumulator.
    fold_op: F,
    // Produces the starting accumulator; invoked once per chunk.
    identity: ID,
}
26 
27 impl<I: IndexedParallelIterator + Debug, ID, F> Debug for FoldChunks<I, ID, F> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result28     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
29         f.debug_struct("Fold")
30             .field("base", &self.base)
31             .field("chunk_size", &self.chunk_size)
32             .finish()
33     }
34 }
35 
36 impl<I, ID, U, F> FoldChunks<I, ID, F>
37 where
38     I: IndexedParallelIterator,
39     ID: Fn() -> U + Send + Sync,
40     F: Fn(U, I::Item) -> U + Send + Sync,
41     U: Send,
42 {
43     /// Creates a new `FoldChunks` iterator
new(base: I, chunk_size: usize, identity: ID, fold_op: F) -> Self44     pub(super) fn new(base: I, chunk_size: usize, identity: ID, fold_op: F) -> Self {
45         FoldChunks {
46             base,
47             chunk_size,
48             identity,
49             fold_op,
50         }
51     }
52 }
53 
54 impl<I, ID, U, F> ParallelIterator for FoldChunks<I, ID, F>
55 where
56     I: IndexedParallelIterator,
57     ID: Fn() -> U + Send + Sync,
58     F: Fn(U, I::Item) -> U + Send + Sync,
59     U: Send,
60 {
61     type Item = U;
62 
drive_unindexed<C>(self, consumer: C) -> C::Result where C: Consumer<U>,63     fn drive_unindexed<C>(self, consumer: C) -> C::Result
64     where
65         C: Consumer<U>,
66     {
67         bridge(self, consumer)
68     }
69 
opt_len(&self) -> Option<usize>70     fn opt_len(&self) -> Option<usize> {
71         Some(self.len())
72     }
73 }
74 
75 impl<I, ID, U, F> IndexedParallelIterator for FoldChunks<I, ID, F>
76 where
77     I: IndexedParallelIterator,
78     ID: Fn() -> U + Send + Sync,
79     F: Fn(U, I::Item) -> U + Send + Sync,
80     U: Send,
81 {
len(&self) -> usize82     fn len(&self) -> usize {
83         div_round_up(self.base.len(), self.chunk_size)
84     }
85 
drive<C>(self, consumer: C) -> C::Result where C: Consumer<Self::Item>,86     fn drive<C>(self, consumer: C) -> C::Result
87     where
88         C: Consumer<Self::Item>,
89     {
90         bridge(self, consumer)
91     }
92 
with_producer<CB>(self, callback: CB) -> CB::Output where CB: ProducerCallback<Self::Item>,93     fn with_producer<CB>(self, callback: CB) -> CB::Output
94     where
95         CB: ProducerCallback<Self::Item>,
96     {
97         let len = self.base.len();
98         return self.base.with_producer(Callback {
99             chunk_size: self.chunk_size,
100             len,
101             identity: self.identity,
102             fold_op: self.fold_op,
103             callback,
104         });
105 
106         struct Callback<CB, ID, F> {
107             chunk_size: usize,
108             len: usize,
109             identity: ID,
110             fold_op: F,
111             callback: CB,
112         }
113 
114         impl<T, CB, ID, U, F> ProducerCallback<T> for Callback<CB, ID, F>
115         where
116             CB: ProducerCallback<U>,
117             ID: Fn() -> U + Send + Sync,
118             F: Fn(U, T) -> U + Send + Sync,
119         {
120             type Output = CB::Output;
121 
122             fn callback<P>(self, base: P) -> CB::Output
123             where
124                 P: Producer<Item = T>,
125             {
126                 let identity = &self.identity;
127                 let fold_op = &self.fold_op;
128                 let fold_iter = move |iter: P::IntoIter| iter.fold(identity(), fold_op);
129                 let producer = ChunkProducer::new(self.chunk_size, self.len, base, fold_iter);
130                 self.callback.callback(producer)
131             }
132         }
133     }
134 }
135 
#[cfg(test)]
mod test {
    use super::*;
    use std::ops::Add;

    #[test]
    fn check_fold_chunks() {
        let letters: Vec<char> = "bishbashbosh!".chars().collect();
        let words: Vec<String> = letters
            .into_par_iter()
            .fold_chunks(4, String::new, |mut acc, ch| {
                acc.push(ch);
                acc
            })
            .collect();

        assert_eq!(words, vec!["bish", "bash", "bosh", "!"]);
    }

    // Plain `fn` items used as the identity / fold arguments in the tests below.
    fn id() -> i32 {
        0
    }

    fn sum<T, U>(x: T, y: U) -> T
    where
        T: Add<U, Output = T>,
    {
        x + y
    }

    #[test]
    #[should_panic(expected = "chunk_size must not be zero")]
    fn check_fold_chunks_zero_size() {
        let _: Vec<i32> = vec![1, 2, 3].into_par_iter().fold_chunks(0, id, sum).collect();
    }

    #[test]
    fn check_fold_chunks_even_size() {
        let folded: Vec<i32> = (1..10).into_par_iter().fold_chunks(3, id, sum).collect();
        assert_eq!(vec![1 + 2 + 3, 4 + 5 + 6, 7 + 8 + 9], folded);
    }

    #[test]
    fn check_fold_chunks_empty() {
        let input: Vec<i32> = Vec::new();
        let expected: Vec<i32> = Vec::new();
        let folded: Vec<i32> = input.into_par_iter().fold_chunks(2, id, sum).collect();
        assert_eq!(expected, folded);
    }

    #[test]
    fn check_fold_chunks_len() {
        assert_eq!(4, (0..8).into_par_iter().fold_chunks(2, id, sum).len());
        assert_eq!(3, (0..9).into_par_iter().fold_chunks(3, id, sum).len());
        assert_eq!(3, (0..8).into_par_iter().fold_chunks(3, id, sum).len());
        assert_eq!(1, (&[1]).par_iter().fold_chunks(3, id, sum).len());
        assert_eq!(0, (0..0).into_par_iter().fold_chunks(3, id, sum).len());
    }

    #[test]
    fn check_fold_chunks_uneven() {
        let cases: Vec<(Vec<u32>, usize, Vec<u32>)> = vec![
            ((0..5).collect(), 3, vec![0 + 1 + 2, 3 + 4]),
            (vec![1], 5, vec![1]),
            ((0..4).collect(), 3, vec![0 + 1 + 2, 3]),
        ];

        for (i, (v, n, expected)) in cases.into_iter().enumerate() {
            let mut res: Vec<u32> = Vec::new();

            // Forward order, borrowing the input.
            v.par_iter()
                .fold_chunks(n, || 0, sum)
                .collect_into_vec(&mut res);
            assert_eq!(expected, res, "Case {} failed", i);

            // Reverse order, consuming the input.
            res.clear();
            v.into_par_iter()
                .fold_chunks(n, || 0, sum)
                .rev()
                .collect_into_vec(&mut res);
            let expected_rev: Vec<u32> = expected.into_iter().rev().collect();
            assert_eq!(expected_rev, res, "Case {} reversed failed", i);
        }
    }
}
237