1 use crate::{
2     engine::{general_purpose::INVALID_VALUE, DecodeEstimate, DecodeMetadata, DecodePaddingMode},
3     DecodeError, DecodeSliceError, PAD_BYTE,
4 };
5 
#[doc(hidden)]
/// Estimate of the decoded length for a given base64 input length.
///
/// Built by `GeneralPurposeEstimate::new` from the encoded length; `decode_helper` also reads
/// `rem` from it.
pub struct GeneralPurposeEstimate {
    /// input len % 4
    rem: usize,
    /// `ceil(input len / 4) * 3`: every started 4-byte group of input is counted as 3 bytes of
    /// output. This is the value reported by `decoded_len_estimate`.
    conservative_decoded_len: usize,
}
12 
13 impl GeneralPurposeEstimate {
new(encoded_len: usize) -> Self14     pub(crate) fn new(encoded_len: usize) -> Self {
15         let rem = encoded_len % 4;
16         Self {
17             rem,
18             conservative_decoded_len: (encoded_len / 4 + (rem > 0) as usize) * 3,
19         }
20     }
21 }
22 
23 impl DecodeEstimate for GeneralPurposeEstimate {
decoded_len_estimate(&self) -> usize24     fn decoded_len_estimate(&self) -> usize {
25         self.conservative_decoded_len
26     }
27 }
28 
/// Decodes `input` into `output` using `decode_table`, returning the decode metadata or an
/// error.
///
/// All quads except the final one are decoded here: first in unrolled 32-byte chunks, then one
/// 4-byte quad at a time. The final quad (which may be partial or carry padding) is handed to
/// `decode_suffix` together with `decode_allow_trailing_bits` and `padding_mode`, and its result
/// is returned as-is.
///
/// `estimate` is passed in so its precomputed `rem` (`input.len() % 4`) is reused here rather
/// than recomputed.
///
/// # Errors
///
/// - whatever `complete_quads_len` rejects: `output` too short for the non-final quads, or a lone
///   trailing byte that is neither a valid symbol nor a pad byte;
/// - `DecodeError::InvalidByte` (converted to `DecodeSliceError`) for an invalid byte in a
///   non-final quad, carrying that byte's offset within `input`;
/// - any error returned by `decode_suffix`.
// We're on the fragile edge of compiler heuristics here. If this is not inlined, slow. If this is
// inlined(always), a different slow. plain ol' inline makes the benchmarks happiest at the moment,
// but this is fragile and the best setting changes with only minor code modifications.
#[inline]
pub(crate) fn decode_helper(
    input: &[u8],
    estimate: GeneralPurposeEstimate,
    output: &mut [u8],
    decode_table: &[u8; 256],
    decode_allow_trailing_bits: bool,
    padding_mode: DecodePaddingMode,
) -> Result<DecodeMetadata, DecodeSliceError> {
    // Byte length of every complete quad except the final one. This call also checks that
    // `output` can hold their decoded bytes, so the slicing below cannot go out of bounds.
    let input_complete_nonterminal_quads_len =
        complete_quads_len(input, estimate.rem, output.len(), decode_table)?;

    // 32 input bytes (8 quads) decode to 24 output bytes per unrolled iteration.
    const UNROLLED_INPUT_CHUNK_SIZE: usize = 32;
    const UNROLLED_OUTPUT_CHUNK_SIZE: usize = UNROLLED_INPUT_CHUNK_SIZE / 4 * 3;

    // Split the non-final quads into a prefix that is a whole number of 32-byte chunks and a
    // remainder of fewer than 8 quads.
    let input_complete_quads_after_unrolled_chunks_len =
        input_complete_nonterminal_quads_len % UNROLLED_INPUT_CHUNK_SIZE;

    let input_unrolled_loop_len =
        input_complete_nonterminal_quads_len - input_complete_quads_after_unrolled_chunks_len;

    // chunks of 32 bytes
    for (chunk_index, chunk) in input[..input_unrolled_loop_len]
        .chunks_exact(UNROLLED_INPUT_CHUNK_SIZE)
        .enumerate()
    {
        // absolute offset of this chunk in `input`, so reported error positions are correct
        let input_index = chunk_index * UNROLLED_INPUT_CHUNK_SIZE;
        let chunk_output = &mut output[chunk_index * UNROLLED_OUTPUT_CHUNK_SIZE
            ..(chunk_index + 1) * UNROLLED_OUTPUT_CHUNK_SIZE];

        // four 8-byte pieces, each decoding into the next 6 bytes of `chunk_output`
        decode_chunk_8(
            &chunk[0..8],
            input_index,
            decode_table,
            &mut chunk_output[0..6],
        )?;
        decode_chunk_8(
            &chunk[8..16],
            input_index + 8,
            decode_table,
            &mut chunk_output[6..12],
        )?;
        decode_chunk_8(
            &chunk[16..24],
            input_index + 16,
            decode_table,
            &mut chunk_output[12..18],
        )?;
        decode_chunk_8(
            &chunk[24..32],
            input_index + 24,
            decode_table,
            &mut chunk_output[18..24],
        )?;
    }

    // remaining quads, except for the last possibly partial one, as it may have padding
    // Output offsets matching the end of the unrolled region and the end of all non-final quads.
    let output_unrolled_loop_len = input_unrolled_loop_len / 4 * 3;
    let output_complete_quad_len = input_complete_nonterminal_quads_len / 4 * 3;
    {
        let output_after_unroll = &mut output[output_unrolled_loop_len..output_complete_quad_len];

        for (chunk_index, chunk) in input
            [input_unrolled_loop_len..input_complete_nonterminal_quads_len]
            .chunks_exact(4)
            .enumerate()
        {
            let chunk_output = &mut output_after_unroll[chunk_index * 3..chunk_index * 3 + 3];

            decode_chunk_4(
                chunk,
                // absolute offset in `input` for error reporting
                input_unrolled_loop_len + chunk_index * 4,
                decode_table,
                chunk_output,
            )?;
        }
    }

    // Decode the final quad, passing the input/output offsets reached so far and the
    // padding / trailing-bits settings.
    super::decode_suffix::decode_suffix(
        input,
        input_complete_nonterminal_quads_len,
        output,
        output_complete_quad_len,
        decode_table,
        decode_allow_trailing_bits,
        padding_mode,
    )
}
122 
123 /// Returns the length of complete quads, except for the last one, even if it is complete.
124 ///
125 /// Returns an error if the output len is not big enough for decoding those complete quads, or if
126 /// the input % 4 == 1, and that last byte is an invalid value other than a pad byte.
127 ///
128 /// - `input` is the base64 input
129 /// - `input_len_rem` is input len % 4
130 /// - `output_len` is the length of the output slice
complete_quads_len( input: &[u8], input_len_rem: usize, output_len: usize, decode_table: &[u8; 256], ) -> Result<usize, DecodeSliceError>131 pub(crate) fn complete_quads_len(
132     input: &[u8],
133     input_len_rem: usize,
134     output_len: usize,
135     decode_table: &[u8; 256],
136 ) -> Result<usize, DecodeSliceError> {
137     debug_assert!(input.len() % 4 == input_len_rem);
138 
139     // detect a trailing invalid byte, like a newline, as a user convenience
140     if input_len_rem == 1 {
141         let last_byte = input[input.len() - 1];
142         // exclude pad bytes; might be part of padding that extends from earlier in the input
143         if last_byte != PAD_BYTE && decode_table[usize::from(last_byte)] == INVALID_VALUE {
144             return Err(DecodeError::InvalidByte(input.len() - 1, last_byte).into());
145         }
146     };
147 
148     // skip last quad, even if it's complete, as it may have padding
149     let input_complete_nonterminal_quads_len = input
150         .len()
151         .saturating_sub(input_len_rem)
152         // if rem was 0, subtract 4 to avoid padding
153         .saturating_sub((input_len_rem == 0) as usize * 4);
154     debug_assert!(
155         input.is_empty() || (1..=4).contains(&(input.len() - input_complete_nonterminal_quads_len))
156     );
157 
158     // check that everything except the last quad handled by decode_suffix will fit
159     if output_len < input_complete_nonterminal_quads_len / 4 * 3 {
160         return Err(DecodeSliceError::OutputSliceTooSmall);
161     };
162     Ok(input_complete_nonterminal_quads_len)
163 }
164 
165 /// Decode 8 bytes of input into 6 bytes of output.
166 ///
167 /// `input` is the 8 bytes to decode.
168 /// `index_at_start_of_input` is the offset in the overall input (used for reporting errors
169 /// accurately)
170 /// `decode_table` is the lookup table for the particular base64 alphabet.
171 /// `output` will have its first 6 bytes overwritten
172 // yes, really inline (worth 30-50% speedup)
173 #[inline(always)]
decode_chunk_8( input: &[u8], index_at_start_of_input: usize, decode_table: &[u8; 256], output: &mut [u8], ) -> Result<(), DecodeError>174 fn decode_chunk_8(
175     input: &[u8],
176     index_at_start_of_input: usize,
177     decode_table: &[u8; 256],
178     output: &mut [u8],
179 ) -> Result<(), DecodeError> {
180     let morsel = decode_table[usize::from(input[0])];
181     if morsel == INVALID_VALUE {
182         return Err(DecodeError::InvalidByte(index_at_start_of_input, input[0]));
183     }
184     let mut accum = u64::from(morsel) << 58;
185 
186     let morsel = decode_table[usize::from(input[1])];
187     if morsel == INVALID_VALUE {
188         return Err(DecodeError::InvalidByte(
189             index_at_start_of_input + 1,
190             input[1],
191         ));
192     }
193     accum |= u64::from(morsel) << 52;
194 
195     let morsel = decode_table[usize::from(input[2])];
196     if morsel == INVALID_VALUE {
197         return Err(DecodeError::InvalidByte(
198             index_at_start_of_input + 2,
199             input[2],
200         ));
201     }
202     accum |= u64::from(morsel) << 46;
203 
204     let morsel = decode_table[usize::from(input[3])];
205     if morsel == INVALID_VALUE {
206         return Err(DecodeError::InvalidByte(
207             index_at_start_of_input + 3,
208             input[3],
209         ));
210     }
211     accum |= u64::from(morsel) << 40;
212 
213     let morsel = decode_table[usize::from(input[4])];
214     if morsel == INVALID_VALUE {
215         return Err(DecodeError::InvalidByte(
216             index_at_start_of_input + 4,
217             input[4],
218         ));
219     }
220     accum |= u64::from(morsel) << 34;
221 
222     let morsel = decode_table[usize::from(input[5])];
223     if morsel == INVALID_VALUE {
224         return Err(DecodeError::InvalidByte(
225             index_at_start_of_input + 5,
226             input[5],
227         ));
228     }
229     accum |= u64::from(morsel) << 28;
230 
231     let morsel = decode_table[usize::from(input[6])];
232     if morsel == INVALID_VALUE {
233         return Err(DecodeError::InvalidByte(
234             index_at_start_of_input + 6,
235             input[6],
236         ));
237     }
238     accum |= u64::from(morsel) << 22;
239 
240     let morsel = decode_table[usize::from(input[7])];
241     if morsel == INVALID_VALUE {
242         return Err(DecodeError::InvalidByte(
243             index_at_start_of_input + 7,
244             input[7],
245         ));
246     }
247     accum |= u64::from(morsel) << 16;
248 
249     output[..6].copy_from_slice(&accum.to_be_bytes()[..6]);
250 
251     Ok(())
252 }
253 
254 /// Like [decode_chunk_8] but for 4 bytes of input and 3 bytes of output.
255 #[inline(always)]
decode_chunk_4( input: &[u8], index_at_start_of_input: usize, decode_table: &[u8; 256], output: &mut [u8], ) -> Result<(), DecodeError>256 fn decode_chunk_4(
257     input: &[u8],
258     index_at_start_of_input: usize,
259     decode_table: &[u8; 256],
260     output: &mut [u8],
261 ) -> Result<(), DecodeError> {
262     let morsel = decode_table[usize::from(input[0])];
263     if morsel == INVALID_VALUE {
264         return Err(DecodeError::InvalidByte(index_at_start_of_input, input[0]));
265     }
266     let mut accum = u32::from(morsel) << 26;
267 
268     let morsel = decode_table[usize::from(input[1])];
269     if morsel == INVALID_VALUE {
270         return Err(DecodeError::InvalidByte(
271             index_at_start_of_input + 1,
272             input[1],
273         ));
274     }
275     accum |= u32::from(morsel) << 20;
276 
277     let morsel = decode_table[usize::from(input[2])];
278     if morsel == INVALID_VALUE {
279         return Err(DecodeError::InvalidByte(
280             index_at_start_of_input + 2,
281             input[2],
282         ));
283     }
284     accum |= u32::from(morsel) << 14;
285 
286     let morsel = decode_table[usize::from(input[3])];
287     if morsel == INVALID_VALUE {
288         return Err(DecodeError::InvalidByte(
289             index_at_start_of_input + 3,
290             input[3],
291         ));
292     }
293     accum |= u32::from(morsel) << 8;
294 
295     output[..3].copy_from_slice(&accum.to_be_bytes()[..3]);
296 
297     Ok(())
298 }
299 
#[cfg(test)]
mod tests {
    use super::*;

    use crate::engine::general_purpose::STANDARD;

    #[test]
    fn decode_chunk_8_writes_only_6_bytes() {
        let input = b"Zm9vYmFy"; // "foobar"
        let mut output = [0_u8, 1, 2, 3, 4, 5, 6, 7];

        decode_chunk_8(&input[..], 0, &STANDARD.decode_table, &mut output).unwrap();
        assert_eq!([b'f', b'o', b'o', b'b', b'a', b'r', 6, 7], output);
    }

    #[test]
    fn decode_chunk_8_reports_offset_of_first_invalid_byte() {
        // two invalid bytes; only the first (chunk offset 4) is reported, shifted by the start index
        let input = b"Zm9v*m!y";
        let mut output = [0_u8; 6];

        let result = decode_chunk_8(&input[..], 100, &STANDARD.decode_table, &mut output);
        assert!(matches!(result, Err(DecodeError::InvalidByte(104, b'*'))));
        // nothing is written when decoding fails
        assert_eq!([0_u8; 6], output);
    }

    #[test]
    fn decode_chunk_4_writes_only_3_bytes() {
        let input = b"Zm9v"; // "foo"
        let mut output = [0_u8, 1, 2, 3];

        decode_chunk_4(&input[..], 0, &STANDARD.decode_table, &mut output).unwrap();
        assert_eq!([b'f', b'o', b'o', 3], output);
    }

    #[test]
    fn decode_chunk_4_reports_offset_of_first_invalid_byte() {
        let input = b"Zm\n9";
        let mut output = [0_u8; 3];

        let result = decode_chunk_4(&input[..], 40, &STANDARD.decode_table, &mut output);
        assert!(matches!(result, Err(DecodeError::InvalidByte(42, b'\n'))));
        assert_eq!([0_u8; 3], output);
    }

    #[test]
    fn complete_quads_len_holds_back_final_quad() {
        let table = &STANDARD.decode_table;

        assert!(matches!(complete_quads_len(b"", 0, 0, table), Ok(0)));
        // a single complete quad is still the final one, so nothing is counted
        assert!(matches!(complete_quads_len(b"Zm9v", 0, 0, table), Ok(0)));
        assert!(matches!(complete_quads_len(b"Zm9vYmFy", 0, 3, table), Ok(4)));
        assert!(matches!(complete_quads_len(b"Zm9vYmF", 3, 3, table), Ok(4)));
        assert!(matches!(complete_quads_len(b"Zm9vYmFyZg", 2, 6, table), Ok(8)));
        // a lone trailing pad byte is not rejected here; it is left with the final quad
        assert!(matches!(complete_quads_len(b"Zm9vYmFy=", 1, 6, table), Ok(8)));
    }

    #[test]
    fn complete_quads_len_rejects_short_output_and_trailing_garbage() {
        let table = &STANDARD.decode_table;

        assert!(matches!(
            complete_quads_len(b"Zm9vYmFy", 0, 2, table),
            Err(DecodeSliceError::OutputSliceTooSmall)
        ));
        assert!(complete_quads_len(b"Zm9vYmFy\n", 1, 6, table).is_err());
    }

    #[test]
    fn estimate_short_lengths() {
        for (range, decoded_len_estimate) in [
            (0..=0, 0),
            (1..=4, 3),
            (5..=8, 6),
            (9..=12, 9),
            (13..=16, 12),
            (17..=20, 15),
        ] {
            for encoded_len in range {
                let estimate = GeneralPurposeEstimate::new(encoded_len);
                assert_eq!(decoded_len_estimate, estimate.decoded_len_estimate());
            }
        }
    }

    #[test]
    fn estimate_via_u128_inflation() {
        // cover both ends of usize
        (0..1000)
            .chain(usize::MAX - 1000..=usize::MAX)
            .for_each(|encoded_len| {
                // inflate to 128 bit type to be able to safely use the easy formulas
                let len_128 = encoded_len as u128;

                let estimate = GeneralPurposeEstimate::new(encoded_len);
                assert_eq!(
                    (len_128 + 3) / 4 * 3,
                    estimate.conservative_decoded_len as u128
                );
            })
    }
}
358