1 use core::iter::FlatMap;
2 
/// A specialized version of `core::iter::FlatMap` for mapping over exact-sized
/// iterators with a function that returns an array.
///
/// `ArrayFlatMap` differs from `FlatMap` in that `ArrayFlatMap` implements
/// `ExactSizeIterator`. Since the result of `F` always has `LEN` elements, if
/// `I` is an exact-sized iterator of length `inner_len` then we know the
/// length of the flat-mapped result is `inner_len * LEN`. (The constructor
/// verifies that this multiplication doesn't overflow `usize`.)
///
/// `ExactSizeIterator` is a safe trait, so `I` may misreport its length. In
/// that case the items produced are unaffected. Only the length reported by
/// `size_hint`/`len` becomes a best-effort estimate; it never underflows.
#[derive(Clone)]
pub struct ArrayFlatMap<I, Item, F, const LEN: usize> {
    inner: FlatMap<I, [Item; LEN], F>,
    /// Number of items not yet yielded, derived from `I::len() * LEN` at
    /// construction and decremented by `next`.
    remaining: usize,
}

impl<I, Item, F, const LEN: usize> ArrayFlatMap<I, Item, F, LEN>
where
    I: ExactSizeIterator,
    F: FnMut(I::Item) -> [Item; LEN],
{
    /// Constructs an `ArrayFlatMap` wrapping the given iterator. The given
    /// function `f` maps each element of `inner` to an array of `LEN` items.
    ///
    /// Returns `None` if the total length `inner.len() * LEN` would overflow
    /// `usize`.
    pub fn new(inner: I, f: F) -> Option<Self> {
        let remaining = inner.len().checked_mul(LEN)?;
        let inner = inner.flat_map(f);
        Some(Self { inner, remaining })
    }
}

impl<I, Item, F, const LEN: usize> Iterator for ArrayFlatMap<I, Item, F, LEN>
where
    I: Iterator,
    F: FnMut(I::Item) -> [Item; LEN],
{
    type Item = Item;

    fn next(&mut self) -> Option<Self::Item> {
        let result = self.inner.next();
        if result.is_some() {
            // `I::len()` is not trusted (safe trait): if it under-reported,
            // saturate instead of panicking (debug) or wrapping (release).
            self.remaining = self.remaining.saturating_sub(1);
        } else {
            // Exhausted: nothing is left, even if `I::len()` over-reported.
            self.remaining = 0;
        }
        result
    }

    /// Required for implementing `ExactSizeIterator`: the lower and upper
    /// bounds are both the tracked `remaining` count.
    fn size_hint(&self) -> (usize, Option<usize>) {
        (self.remaining, Some(self.remaining))
    }
}

// `len()` comes from the default implementation, which relies on the exact
// `size_hint` above.
impl<I, Item, F, const LEN: usize> ExactSizeIterator for ArrayFlatMap<I, Item, F, LEN>
where
    I: Iterator,
    F: FnMut(I::Item) -> [Item; LEN],
{
}
58 
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_array_flat_map() {
        // Each case is (input elements, per-element mapping, expected flat output).
        static TEST_CASES: &[(&[u16], fn(u16) -> [u8; 2], &[u8])] = &[
            // Nothing in, nothing out.
            (&[], u16::to_be_bytes, &[]),
            // Big-endian mapping keeps each element's bytes in high-to-low order.
            (
                &[0x0102, 0x0304, 0x0506],
                u16::to_be_bytes,
                &[1, 2, 3, 4, 5, 6],
            ),
            // Little-endian mapping swaps the bytes within each element.
            (
                &[0x0102, 0x0304, 0x0506],
                u16::to_le_bytes,
                &[2, 1, 4, 3, 6, 5],
            ),
        ];

        for &(input, map_fn, expected) in TEST_CASES {
            let flat = ArrayFlatMap::new(input.iter().copied(), map_fn).unwrap();
            super::super::test::assert_iterator(flat, expected);
        }
    }

    // Does ArrayFlatMap::new() handle overflow correctly?
    #[test]
    fn test_array_flat_map_len_overflow() {
        // Yields `left, left - 1, ..., 1` and always reports its exact length,
        // so huge lengths can be tested without allocating anything.
        struct Countdown {
            left: usize,
        }

        impl Iterator for Countdown {
            type Item = usize;

            fn next(&mut self) -> Option<Self::Item> {
                if self.left == 0 {
                    return None;
                }
                let current = self.left;
                self.left -= 1;
                Some(current)
            }

            fn size_hint(&self) -> (usize, Option<usize>) {
                (self.left, Some(self.left))
            }
        }

        impl ExactSizeIterator for Countdown {}

        // `usize::to_be_bytes` yields one byte per byte of a machine word.
        const WORD_BYTES: usize = core::mem::size_of::<usize>();
        // Largest input length whose product with WORD_BYTES still fits in usize.
        const MAX: usize = usize::MAX / WORD_BYTES;

        static TEST_CASES: &[(usize, bool)] = &[(MAX, true), (MAX + 1, false)];
        for &(input_len, expect_some) in TEST_CASES {
            let maybe_mapped =
                ArrayFlatMap::new(Countdown { left: input_len }, usize::to_be_bytes);
            assert_eq!(maybe_mapped.is_some(), expect_some);
            if let Some(mapped) = maybe_mapped {
                assert_eq!(mapped.len(), input_len * WORD_BYTES);
            }
        }
    }
}
127