1 // rstest_reuse template functions have unused variables
2 #![allow(unused_variables)]
3 
4 use rand::{
5     self,
6     distributions::{self, Distribution as _},
7     rngs, Rng as _, SeedableRng as _,
8 };
9 use rstest::rstest;
10 use rstest_reuse::{apply, template};
11 use std::{collections, fmt, io::Read as _};
12 
13 use crate::{
14     alphabet::{Alphabet, STANDARD},
15     encode::add_padding,
16     encoded_len,
17     engine::{
18         general_purpose, naive, Config, DecodeEstimate, DecodeMetadata, DecodePaddingMode, Engine,
19     },
20     read::DecoderReader,
21     tests::{assert_encode_sanity, random_alphabet, random_config},
22     DecodeError, DecodeSliceError, PAD_BYTE,
23 };
24 
// Test template: tests annotated with `#[apply(all_engines)]` are run once per `case` listed
// below, each case supplying a different engine wrapper value as `engine_wrapper`.
// the case::foo syntax includes the "foo" in the generated test method names
#[template]
#[rstest(engine_wrapper,
case::general_purpose(GeneralPurposeWrapper {}),
case::naive(NaiveWrapper {}),
case::decoder_reader(DecoderReaderEngineWrapper {}),
)]
fn all_engines<E: EngineWrapper>(engine_wrapper: E) {}
33 
/// Some decode tests don't make sense for use with `DecoderReader` as they are difficult to
/// reason about or otherwise inapplicable given how `DecoderReader` slices up its input along
/// chunk boundaries.
///
/// This template lists the same cases as `all_engines`, minus the `decoder_reader` case.
#[template]
#[rstest(engine_wrapper,
case::general_purpose(GeneralPurposeWrapper {}),
case::naive(NaiveWrapper {}),
)]
fn all_engines_except_decoder_reader<E: EngineWrapper>(engine_wrapper: E) {}
43 
44 #[apply(all_engines)]
rfc_test_vectors_std_alphabet<E: EngineWrapper>(engine_wrapper: E)45 fn rfc_test_vectors_std_alphabet<E: EngineWrapper>(engine_wrapper: E) {
46     let data = vec![
47         ("", ""),
48         ("f", "Zg=="),
49         ("fo", "Zm8="),
50         ("foo", "Zm9v"),
51         ("foob", "Zm9vYg=="),
52         ("fooba", "Zm9vYmE="),
53         ("foobar", "Zm9vYmFy"),
54     ];
55 
56     let engine = E::standard();
57     let engine_no_padding = E::standard_unpadded();
58 
59     for (orig, encoded) in &data {
60         let encoded_without_padding = encoded.trim_end_matches('=');
61 
62         // unpadded
63         {
64             let mut encode_buf = [0_u8; 8];
65             let mut decode_buf = [0_u8; 6];
66 
67             let encode_len =
68                 engine_no_padding.internal_encode(orig.as_bytes(), &mut encode_buf[..]);
69             assert_eq!(
70                 &encoded_without_padding,
71                 &std::str::from_utf8(&encode_buf[0..encode_len]).unwrap()
72             );
73             let decode_len = engine_no_padding
74                 .decode_slice_unchecked(encoded_without_padding.as_bytes(), &mut decode_buf[..])
75                 .unwrap();
76             assert_eq!(orig.len(), decode_len);
77 
78             assert_eq!(
79                 orig,
80                 &std::str::from_utf8(&decode_buf[0..decode_len]).unwrap()
81             );
82 
83             // if there was any padding originally, the no padding engine won't decode it
84             if encoded.as_bytes().contains(&PAD_BYTE) {
85                 assert_eq!(
86                     Err(DecodeError::InvalidPadding),
87                     engine_no_padding.decode(encoded)
88                 )
89             }
90         }
91 
92         // padded
93         {
94             let mut encode_buf = [0_u8; 8];
95             let mut decode_buf = [0_u8; 6];
96 
97             let encode_len = engine.internal_encode(orig.as_bytes(), &mut encode_buf[..]);
98             assert_eq!(
99                 // doesn't have padding added yet
100                 &encoded_without_padding,
101                 &std::str::from_utf8(&encode_buf[0..encode_len]).unwrap()
102             );
103             let pad_len = add_padding(encode_len, &mut encode_buf[encode_len..]);
104             assert_eq!(encoded.as_bytes(), &encode_buf[..encode_len + pad_len]);
105 
106             let decode_len = engine
107                 .decode_slice_unchecked(encoded.as_bytes(), &mut decode_buf[..])
108                 .unwrap();
109             assert_eq!(orig.len(), decode_len);
110 
111             assert_eq!(
112                 orig,
113                 &std::str::from_utf8(&decode_buf[0..decode_len]).unwrap()
114             );
115 
116             // if there was (canonical) padding, and we remove it, the standard engine won't decode
117             if encoded.as_bytes().contains(&PAD_BYTE) {
118                 assert_eq!(
119                     Err(DecodeError::InvalidPadding),
120                     engine.decode(encoded_without_padding)
121                 )
122             }
123         }
124     }
125 }
126 
127 #[apply(all_engines)]
roundtrip_random<E: EngineWrapper>(engine_wrapper: E)128 fn roundtrip_random<E: EngineWrapper>(engine_wrapper: E) {
129     let mut rng = seeded_rng();
130 
131     let mut orig_data = Vec::<u8>::new();
132     let mut encode_buf = Vec::<u8>::new();
133     let mut decode_buf = Vec::<u8>::new();
134 
135     let len_range = distributions::Uniform::new(1, 1_000);
136 
137     for _ in 0..10_000 {
138         let engine = E::random(&mut rng);
139 
140         orig_data.clear();
141         encode_buf.clear();
142         decode_buf.clear();
143 
144         let (orig_len, _, encoded_len) = generate_random_encoded_data(
145             &engine,
146             &mut orig_data,
147             &mut encode_buf,
148             &mut rng,
149             &len_range,
150         );
151 
152         // exactly the right size
153         decode_buf.resize(orig_len, 0);
154 
155         let dec_len = engine
156             .decode_slice_unchecked(&encode_buf[0..encoded_len], &mut decode_buf[..])
157             .unwrap();
158 
159         assert_eq!(orig_len, dec_len);
160         assert_eq!(&orig_data[..], &decode_buf[..dec_len]);
161     }
162 }
163 
164 #[apply(all_engines)]
encode_doesnt_write_extra_bytes<E: EngineWrapper>(engine_wrapper: E)165 fn encode_doesnt_write_extra_bytes<E: EngineWrapper>(engine_wrapper: E) {
166     let mut rng = seeded_rng();
167 
168     let mut orig_data = Vec::<u8>::new();
169     let mut encode_buf = Vec::<u8>::new();
170     let mut encode_buf_backup = Vec::<u8>::new();
171 
172     let input_len_range = distributions::Uniform::new(0, 1000);
173 
174     for _ in 0..10_000 {
175         let engine = E::random(&mut rng);
176         let padded = engine.config().encode_padding();
177 
178         orig_data.clear();
179         encode_buf.clear();
180         encode_buf_backup.clear();
181 
182         let orig_len = fill_rand(&mut orig_data, &mut rng, &input_len_range);
183 
184         let prefix_len = 1024;
185         // plenty of prefix and suffix
186         fill_rand_len(&mut encode_buf, &mut rng, prefix_len * 2 + orig_len * 2);
187         encode_buf_backup.extend_from_slice(&encode_buf[..]);
188 
189         let expected_encode_len_no_pad = encoded_len(orig_len, false).unwrap();
190 
191         let encoded_len_no_pad =
192             engine.internal_encode(&orig_data[..], &mut encode_buf[prefix_len..]);
193         assert_eq!(expected_encode_len_no_pad, encoded_len_no_pad);
194 
195         // no writes past what it claimed to write
196         assert_eq!(&encode_buf_backup[..prefix_len], &encode_buf[..prefix_len]);
197         assert_eq!(
198             &encode_buf_backup[(prefix_len + encoded_len_no_pad)..],
199             &encode_buf[(prefix_len + encoded_len_no_pad)..]
200         );
201 
202         let encoded_data = &encode_buf[prefix_len..(prefix_len + encoded_len_no_pad)];
203         assert_encode_sanity(
204             std::str::from_utf8(encoded_data).unwrap(),
205             // engines don't pad
206             false,
207             orig_len,
208         );
209 
210         // pad so we can decode it in case our random engine requires padding
211         let pad_len = if padded {
212             add_padding(
213                 encoded_len_no_pad,
214                 &mut encode_buf[prefix_len + encoded_len_no_pad..],
215             )
216         } else {
217             0
218         };
219 
220         assert_eq!(
221             orig_data,
222             engine
223                 .decode(&encode_buf[prefix_len..(prefix_len + encoded_len_no_pad + pad_len)],)
224                 .unwrap()
225         );
226     }
227 }
228 
229 #[apply(all_engines)]
encode_engine_slice_fits_into_precisely_sized_slice<E: EngineWrapper>(engine_wrapper: E)230 fn encode_engine_slice_fits_into_precisely_sized_slice<E: EngineWrapper>(engine_wrapper: E) {
231     let mut orig_data = Vec::new();
232     let mut encoded_data = Vec::new();
233     let mut decoded = Vec::new();
234 
235     let input_len_range = distributions::Uniform::new(0, 1000);
236 
237     let mut rng = rngs::SmallRng::from_entropy();
238 
239     for _ in 0..10_000 {
240         orig_data.clear();
241         encoded_data.clear();
242         decoded.clear();
243 
244         let input_len = input_len_range.sample(&mut rng);
245 
246         for _ in 0..input_len {
247             orig_data.push(rng.gen());
248         }
249 
250         let engine = E::random(&mut rng);
251 
252         let encoded_size = encoded_len(input_len, engine.config().encode_padding()).unwrap();
253 
254         encoded_data.resize(encoded_size, 0);
255 
256         assert_eq!(
257             encoded_size,
258             engine.encode_slice(&orig_data, &mut encoded_data).unwrap()
259         );
260 
261         assert_encode_sanity(
262             std::str::from_utf8(&encoded_data[0..encoded_size]).unwrap(),
263             engine.config().encode_padding(),
264             input_len,
265         );
266 
267         engine
268             .decode_vec(&encoded_data[0..encoded_size], &mut decoded)
269             .unwrap();
270         assert_eq!(orig_data, decoded);
271     }
272 }
273 
274 #[apply(all_engines)]
decode_doesnt_write_extra_bytes<E>(engine_wrapper: E) where E: EngineWrapper, <<E as EngineWrapper>::Engine as Engine>::Config: fmt::Debug,275 fn decode_doesnt_write_extra_bytes<E>(engine_wrapper: E)
276 where
277     E: EngineWrapper,
278     <<E as EngineWrapper>::Engine as Engine>::Config: fmt::Debug,
279 {
280     let mut rng = seeded_rng();
281 
282     let mut orig_data = Vec::<u8>::new();
283     let mut encode_buf = Vec::<u8>::new();
284     let mut decode_buf = Vec::<u8>::new();
285     let mut decode_buf_backup = Vec::<u8>::new();
286 
287     let len_range = distributions::Uniform::new(1, 1_000);
288 
289     for _ in 0..10_000 {
290         let engine = E::random(&mut rng);
291 
292         orig_data.clear();
293         encode_buf.clear();
294         decode_buf.clear();
295         decode_buf_backup.clear();
296 
297         let orig_len = fill_rand(&mut orig_data, &mut rng, &len_range);
298         encode_buf.resize(orig_len * 2 + 100, 0);
299 
300         let encoded_len = engine
301             .encode_slice(&orig_data[..], &mut encode_buf[..])
302             .unwrap();
303         encode_buf.truncate(encoded_len);
304 
305         // oversize decode buffer so we can easily tell if it writes anything more than
306         // just the decoded data
307         let prefix_len = 1024;
308         // plenty of prefix and suffix
309         fill_rand_len(&mut decode_buf, &mut rng, prefix_len * 2 + orig_len * 2);
310         decode_buf_backup.extend_from_slice(&decode_buf[..]);
311 
312         let dec_len = engine
313             .decode_slice_unchecked(&encode_buf, &mut decode_buf[prefix_len..])
314             .unwrap();
315 
316         assert_eq!(orig_len, dec_len);
317         assert_eq!(
318             &orig_data[..],
319             &decode_buf[prefix_len..prefix_len + dec_len]
320         );
321         assert_eq!(&decode_buf_backup[..prefix_len], &decode_buf[..prefix_len]);
322         assert_eq!(
323             &decode_buf_backup[prefix_len + dec_len..],
324             &decode_buf[prefix_len + dec_len..]
325         );
326     }
327 }
328 
329 #[apply(all_engines)]
decode_detect_invalid_last_symbol<E: EngineWrapper>(engine_wrapper: E)330 fn decode_detect_invalid_last_symbol<E: EngineWrapper>(engine_wrapper: E) {
331     // 0xFF -> "/w==", so all letters > w, 0-9, and '+', '/' should get InvalidLastSymbol
332     let engine = E::standard();
333 
334     assert_eq!(Ok(vec![0x89, 0x85]), engine.decode("iYU="));
335     assert_eq!(Ok(vec![0xFF]), engine.decode("/w=="));
336 
337     for (suffix, offset) in vec![
338         // suffix, offset of bad byte from start of suffix
339         ("/x==", 1_usize),
340         ("/z==", 1_usize),
341         ("/0==", 1_usize),
342         ("/9==", 1_usize),
343         ("/+==", 1_usize),
344         ("//==", 1_usize),
345         // trailing 01
346         ("iYV=", 2_usize),
347         // trailing 10
348         ("iYW=", 2_usize),
349         // trailing 11
350         ("iYX=", 2_usize),
351     ] {
352         for prefix_quads in 0..256 {
353             let mut encoded = "AAAA".repeat(prefix_quads);
354             encoded.push_str(suffix);
355 
356             assert_eq!(
357                 Err(DecodeError::InvalidLastSymbol(
358                     encoded.len() - 4 + offset,
359                     suffix.as_bytes()[offset],
360                 )),
361                 engine.decode(encoded.as_str())
362             );
363         }
364     }
365 }
366 
367 #[apply(all_engines)]
decode_detect_1_valid_symbol_in_last_quad_invalid_length<E: EngineWrapper>(engine_wrapper: E)368 fn decode_detect_1_valid_symbol_in_last_quad_invalid_length<E: EngineWrapper>(engine_wrapper: E) {
369     for len in (0_usize..256).map(|len| len * 4 + 1) {
370         for mode in all_pad_modes() {
371             let mut input = vec![b'A'; len];
372 
373             let engine = E::standard_with_pad_mode(true, mode);
374 
375             assert_eq!(Err(DecodeError::InvalidLength(len)), engine.decode(&input));
376             // if we add padding, then the first pad byte in the quad is invalid because it should
377             // be the second symbol
378             for _ in 0..3 {
379                 input.push(PAD_BYTE);
380                 assert_eq!(
381                     Err(DecodeError::InvalidByte(len, PAD_BYTE)),
382                     engine.decode(&input)
383                 );
384             }
385         }
386     }
387 }
388 
389 #[apply(all_engines)]
decode_detect_1_invalid_byte_in_last_quad_invalid_byte<E: EngineWrapper>(engine_wrapper: E)390 fn decode_detect_1_invalid_byte_in_last_quad_invalid_byte<E: EngineWrapper>(engine_wrapper: E) {
391     for prefix_len in (0_usize..256).map(|len| len * 4) {
392         for mode in all_pad_modes() {
393             let mut input = vec![b'A'; prefix_len];
394             input.push(b'*');
395 
396             let engine = E::standard_with_pad_mode(true, mode);
397 
398             assert_eq!(
399                 Err(DecodeError::InvalidByte(prefix_len, b'*')),
400                 engine.decode(&input)
401             );
402             // adding padding doesn't matter
403             for _ in 0..3 {
404                 input.push(PAD_BYTE);
405                 assert_eq!(
406                     Err(DecodeError::InvalidByte(prefix_len, b'*')),
407                     engine.decode(&input)
408                 );
409             }
410         }
411     }
412 }
413 
414 #[apply(all_engines)]
decode_detect_invalid_last_symbol_every_possible_two_symbols<E: EngineWrapper>( engine_wrapper: E, )415 fn decode_detect_invalid_last_symbol_every_possible_two_symbols<E: EngineWrapper>(
416     engine_wrapper: E,
417 ) {
418     let engine = E::standard();
419 
420     let mut base64_to_bytes = collections::HashMap::new();
421 
422     for b in 0_u8..=255 {
423         let mut b64 = vec![0_u8; 4];
424         assert_eq!(2, engine.internal_encode(&[b], &mut b64[..]));
425         let _ = add_padding(2, &mut b64[2..]);
426 
427         assert!(base64_to_bytes.insert(b64, vec![b]).is_none());
428     }
429 
430     // every possible combination of trailing symbols must either decode to 1 byte or get InvalidLastSymbol, with or without any leading chunks
431 
432     let mut prefix = Vec::new();
433     for _ in 0..256 {
434         let mut clone = prefix.clone();
435 
436         let mut symbols = [0_u8; 4];
437         for &s1 in STANDARD.symbols.iter() {
438             symbols[0] = s1;
439             for &s2 in STANDARD.symbols.iter() {
440                 symbols[1] = s2;
441                 symbols[2] = PAD_BYTE;
442                 symbols[3] = PAD_BYTE;
443 
444                 // chop off previous symbols
445                 clone.truncate(prefix.len());
446                 clone.extend_from_slice(&symbols[..]);
447                 let decoded_prefix_len = prefix.len() / 4 * 3;
448 
449                 match base64_to_bytes.get(&symbols[..]) {
450                     Some(bytes) => {
451                         let res = engine
452                             .decode(&clone)
453                             // remove prefix
454                             .map(|decoded| decoded[decoded_prefix_len..].to_vec());
455 
456                         assert_eq!(Ok(bytes.clone()), res);
457                     }
458                     None => assert_eq!(
459                         Err(DecodeError::InvalidLastSymbol(1, s2)),
460                         engine.decode(&symbols[..])
461                     ),
462                 }
463             }
464         }
465 
466         prefix.extend_from_slice(b"AAAA");
467     }
468 }
469 
470 #[apply(all_engines)]
decode_detect_invalid_last_symbol_every_possible_three_symbols<E: EngineWrapper>( engine_wrapper: E, )471 fn decode_detect_invalid_last_symbol_every_possible_three_symbols<E: EngineWrapper>(
472     engine_wrapper: E,
473 ) {
474     let engine = E::standard();
475 
476     let mut base64_to_bytes = collections::HashMap::new();
477 
478     let mut bytes = [0_u8; 2];
479     for b1 in 0_u8..=255 {
480         bytes[0] = b1;
481         for b2 in 0_u8..=255 {
482             bytes[1] = b2;
483             let mut b64 = vec![0_u8; 4];
484             assert_eq!(3, engine.internal_encode(&bytes, &mut b64[..]));
485             let _ = add_padding(3, &mut b64[3..]);
486 
487             let mut v = Vec::with_capacity(2);
488             v.extend_from_slice(&bytes[..]);
489 
490             assert!(base64_to_bytes.insert(b64, v).is_none());
491         }
492     }
493 
494     // every possible combination of symbols must either decode to 2 bytes or get InvalidLastSymbol, with or without any leading chunks
495 
496     let mut prefix = Vec::new();
497     let mut input = Vec::new();
498     for _ in 0..256 {
499         input.clear();
500         input.extend_from_slice(&prefix);
501 
502         let mut symbols = [0_u8; 4];
503         for &s1 in STANDARD.symbols.iter() {
504             symbols[0] = s1;
505             for &s2 in STANDARD.symbols.iter() {
506                 symbols[1] = s2;
507                 for &s3 in STANDARD.symbols.iter() {
508                     symbols[2] = s3;
509                     symbols[3] = PAD_BYTE;
510 
511                     // chop off previous symbols
512                     input.truncate(prefix.len());
513                     input.extend_from_slice(&symbols[..]);
514                     let decoded_prefix_len = prefix.len() / 4 * 3;
515 
516                     match base64_to_bytes.get(&symbols[..]) {
517                         Some(bytes) => {
518                             let res = engine
519                                 .decode(&input)
520                                 // remove prefix
521                                 .map(|decoded| decoded[decoded_prefix_len..].to_vec());
522 
523                             assert_eq!(Ok(bytes.clone()), res);
524                         }
525                         None => assert_eq!(
526                             Err(DecodeError::InvalidLastSymbol(2, s3)),
527                             engine.decode(&symbols[..])
528                         ),
529                     }
530                 }
531             }
532         }
533         prefix.extend_from_slice(b"AAAA");
534     }
535 }
536 
537 #[apply(all_engines)]
decode_invalid_trailing_bits_ignored_when_configured<E: EngineWrapper>(engine_wrapper: E)538 fn decode_invalid_trailing_bits_ignored_when_configured<E: EngineWrapper>(engine_wrapper: E) {
539     let strict = E::standard();
540     let forgiving = E::standard_allow_trailing_bits();
541 
542     fn assert_tolerant_decode<E: Engine>(
543         engine: &E,
544         input: &mut String,
545         b64_prefix_len: usize,
546         expected_decode_bytes: Vec<u8>,
547         data: &str,
548     ) {
549         let prefixed = prefixed_data(input, b64_prefix_len, data);
550         let decoded = engine.decode(prefixed);
551         // prefix is always complete chunks
552         let decoded_prefix_len = b64_prefix_len / 4 * 3;
553         assert_eq!(
554             Ok(expected_decode_bytes),
555             decoded.map(|v| v[decoded_prefix_len..].to_vec())
556         );
557     }
558 
559     let mut prefix = String::new();
560     for _ in 0..256 {
561         let mut input = prefix.clone();
562 
563         // example from https://github.com/marshallpierce/rust-base64/issues/75
564         assert!(strict
565             .decode(prefixed_data(&mut input, prefix.len(), "/w=="))
566             .is_ok());
567         assert!(strict
568             .decode(prefixed_data(&mut input, prefix.len(), "iYU="))
569             .is_ok());
570         // trailing 01
571         assert_tolerant_decode(&forgiving, &mut input, prefix.len(), vec![255], "/x==");
572         assert_tolerant_decode(&forgiving, &mut input, prefix.len(), vec![137, 133], "iYV=");
573         // trailing 10
574         assert_tolerant_decode(&forgiving, &mut input, prefix.len(), vec![255], "/y==");
575         assert_tolerant_decode(&forgiving, &mut input, prefix.len(), vec![137, 133], "iYW=");
576         // trailing 11
577         assert_tolerant_decode(&forgiving, &mut input, prefix.len(), vec![255], "/z==");
578         assert_tolerant_decode(&forgiving, &mut input, prefix.len(), vec![137, 133], "iYX=");
579 
580         prefix.push_str("AAAA");
581     }
582 }
583 
584 #[apply(all_engines)]
decode_invalid_byte_error<E: EngineWrapper>(engine_wrapper: E)585 fn decode_invalid_byte_error<E: EngineWrapper>(engine_wrapper: E) {
586     let mut rng = seeded_rng();
587 
588     let mut orig_data = Vec::<u8>::new();
589     let mut encode_buf = Vec::<u8>::new();
590     let mut decode_buf = Vec::<u8>::new();
591 
592     let len_range = distributions::Uniform::new(1, 1_000);
593 
594     for _ in 0..100_000 {
595         let alphabet = random_alphabet(&mut rng);
596         let engine = E::random_alphabet(&mut rng, alphabet);
597 
598         orig_data.clear();
599         encode_buf.clear();
600         decode_buf.clear();
601 
602         let (orig_len, encoded_len_just_data, encoded_len_with_padding) =
603             generate_random_encoded_data(
604                 &engine,
605                 &mut orig_data,
606                 &mut encode_buf,
607                 &mut rng,
608                 &len_range,
609             );
610 
611         // exactly the right size
612         decode_buf.resize(orig_len, 0);
613 
614         // replace one encoded byte with an invalid byte
615         let invalid_byte: u8 = loop {
616             let byte: u8 = rng.gen();
617 
618             if alphabet.symbols.contains(&byte) || byte == PAD_BYTE {
619                 continue;
620             } else {
621                 break byte;
622             }
623         };
624 
625         let invalid_range = distributions::Uniform::new(0, orig_len);
626         let invalid_index = invalid_range.sample(&mut rng);
627         encode_buf[invalid_index] = invalid_byte;
628 
629         assert_eq!(
630             Err(DecodeError::InvalidByte(invalid_index, invalid_byte)),
631             engine.decode_slice_unchecked(
632                 &encode_buf[0..encoded_len_with_padding],
633                 &mut decode_buf[..],
634             )
635         );
636     }
637 }
638 
639 /// Any amount of padding anywhere before the final non padding character = invalid byte at first
640 /// pad byte.
641 /// From this and [decode_padding_before_final_non_padding_char_error_invalid_byte_at_first_pad_non_canonical_padding_suffix_all_modes],
642 /// we know padding must extend contiguously to the end of the input.
643 #[apply(all_engines)]
decode_padding_before_final_non_padding_char_error_invalid_byte_at_first_pad_all_modes< E: EngineWrapper, >( engine_wrapper: E, )644 fn decode_padding_before_final_non_padding_char_error_invalid_byte_at_first_pad_all_modes<
645     E: EngineWrapper,
646 >(
647     engine_wrapper: E,
648 ) {
649     // Different amounts of padding, w/ offset from end for the last non-padding char.
650     // Only canonical padding, so Canonical mode will work.
651     let suffixes = &[("AA==", 2), ("AAA=", 1), ("AAAA", 0)];
652 
653     for mode in pad_modes_allowing_padding() {
654         // We don't encode, so we don't care about encode padding.
655         let engine = E::standard_with_pad_mode(true, mode);
656 
657         decode_padding_before_final_non_padding_char_error_invalid_byte_at_first_pad(
658             engine,
659             suffixes.as_slice(),
660         );
661     }
662 }
663 
664 /// See [decode_padding_before_final_non_padding_char_error_invalid_byte_at_first_pad_all_modes]
665 #[apply(all_engines)]
decode_padding_before_final_non_padding_char_error_invalid_byte_at_first_pad_non_canonical_padding_suffix< E: EngineWrapper, >( engine_wrapper: E, )666 fn decode_padding_before_final_non_padding_char_error_invalid_byte_at_first_pad_non_canonical_padding_suffix<
667     E: EngineWrapper,
668 >(
669     engine_wrapper: E,
670 ) {
671     // Different amounts of padding, w/ offset from end for the last non-padding char, and
672     // non-canonical padding.
673     let suffixes = [
674         ("AA==", 2),
675         ("AA=", 1),
676         ("AA", 0),
677         ("AAA=", 1),
678         ("AAA", 0),
679         ("AAAA", 0),
680     ];
681 
682     // We don't encode, so we don't care about encode padding.
683     // Decoding is indifferent so that we don't get caught by missing padding on the last quad
684     let engine = E::standard_with_pad_mode(true, DecodePaddingMode::Indifferent);
685 
686     decode_padding_before_final_non_padding_char_error_invalid_byte_at_first_pad(
687         engine,
688         suffixes.as_slice(),
689     )
690 }
691 
decode_padding_before_final_non_padding_char_error_invalid_byte_at_first_pad( engine: impl Engine, suffixes: &[(&str, usize)], )692 fn decode_padding_before_final_non_padding_char_error_invalid_byte_at_first_pad(
693     engine: impl Engine,
694     suffixes: &[(&str, usize)],
695 ) {
696     let mut rng = seeded_rng();
697 
698     let prefix_quads_range = distributions::Uniform::from(0..=256);
699 
700     for _ in 0..100_000 {
701         for (suffix, suffix_offset) in suffixes.iter() {
702             let mut s = "AAAA".repeat(prefix_quads_range.sample(&mut rng));
703             s.push_str(suffix);
704             let mut encoded = s.into_bytes();
705 
706             // calculate a range to write padding into that leaves at least one non padding char
707             let last_non_padding_offset = encoded.len() - 1 - suffix_offset;
708 
709             // don't include last non padding char as it must stay not padding
710             let padding_end = rng.gen_range(0..last_non_padding_offset);
711 
712             // don't use more than 100 bytes of padding, but also use shorter lengths when
713             // padding_end is near the start of the encoded data to avoid biasing to padding
714             // the entire prefix on short lengths
715             let padding_len = rng.gen_range(1..=usize::min(100, padding_end + 1));
716             let padding_start = padding_end.saturating_sub(padding_len);
717 
718             encoded[padding_start..=padding_end].fill(PAD_BYTE);
719 
720             // should still have non-padding before any final padding
721             assert_ne!(PAD_BYTE, encoded[last_non_padding_offset]);
722             assert_eq!(
723                 Err(DecodeError::InvalidByte(padding_start, PAD_BYTE)),
724                 engine.decode(&encoded),
725                 "len: {}, input: {}",
726                 encoded.len(),
727                 String::from_utf8(encoded).unwrap()
728             );
729         }
730     }
731 }
732 
733 /// Any amount of padding before final chunk that crosses over into final chunk with 1-4 bytes =
734 /// invalid byte at first pad byte.
735 /// From this we know the padding must start in the final chunk.
736 #[apply(all_engines)]
decode_padding_starts_before_final_chunk_error_invalid_byte_at_first_pad<E: EngineWrapper>( engine_wrapper: E, )737 fn decode_padding_starts_before_final_chunk_error_invalid_byte_at_first_pad<E: EngineWrapper>(
738     engine_wrapper: E,
739 ) {
740     let mut rng = seeded_rng();
741 
742     // must have at least one prefix quad
743     let prefix_quads_range = distributions::Uniform::from(1..256);
744     let suffix_pad_len_range = distributions::Uniform::from(1..=4);
745     // don't use no-padding mode, as the reader decode might decode a block that ends with
746     // valid padding, which should then be referenced when encountering the later invalid byte
747     for mode in pad_modes_allowing_padding() {
748         // we don't encode so we don't care about encode padding
749         let engine = E::standard_with_pad_mode(true, mode);
750         for _ in 0..100_000 {
751             let suffix_len = suffix_pad_len_range.sample(&mut rng);
752             // all 0 bits so we don't hit InvalidLastSymbol with the reader decoder
753             let mut encoded = "AAAA"
754                 .repeat(prefix_quads_range.sample(&mut rng))
755                 .into_bytes();
756             encoded.resize(encoded.len() + suffix_len, PAD_BYTE);
757 
758             // amount of padding must be long enough to extend back from suffix into previous
759             // quads
760             let padding_len = rng.gen_range(suffix_len + 1..encoded.len());
761             // no non-padding after padding in this test, so padding goes to the end
762             let padding_start = encoded.len() - padding_len;
763             encoded[padding_start..].fill(PAD_BYTE);
764 
765             assert_eq!(
766                 Err(DecodeError::InvalidByte(padding_start, PAD_BYTE)),
767                 engine.decode(&encoded),
768                 "suffix_len: {}, padding_len: {}, b64: {}",
769                 suffix_len,
770                 padding_len,
771                 std::str::from_utf8(&encoded).unwrap()
772             );
773         }
774     }
775 }
776 
777 /// 0-1 bytes of data before any amount of padding in final chunk = invalid byte, since padding
778 /// is not valid data (consistent with error for pad bytes in earlier chunks).
779 /// From this we know there must be 2-3 bytes of data before padding
780 #[apply(all_engines)]
decode_too_little_data_before_padding_error_invalid_byte<E: EngineWrapper>(engine_wrapper: E)781 fn decode_too_little_data_before_padding_error_invalid_byte<E: EngineWrapper>(engine_wrapper: E) {
782     let mut rng = seeded_rng();
783 
784     // want to test no prefix quad case, so start at 0
785     let prefix_quads_range = distributions::Uniform::from(0_usize..256);
786     let suffix_data_len_range = distributions::Uniform::from(0_usize..=1);
787     for mode in all_pad_modes() {
788         // we don't encode so we don't care about encode padding
789         let engine = E::standard_with_pad_mode(true, mode);
790         for _ in 0..100_000 {
791             let suffix_data_len = suffix_data_len_range.sample(&mut rng);
792             let prefix_quad_len = prefix_quads_range.sample(&mut rng);
793 
794             // for all possible padding lengths
795             for padding_len in 1..=(4 - suffix_data_len) {
796                 let mut encoded = "ABCD".repeat(prefix_quad_len).into_bytes();
797                 encoded.resize(encoded.len() + suffix_data_len, b'A');
798                 encoded.resize(encoded.len() + padding_len, PAD_BYTE);
799 
800                 assert_eq!(
801                     Err(DecodeError::InvalidByte(
802                         prefix_quad_len * 4 + suffix_data_len,
803                         PAD_BYTE,
804                     )),
805                     engine.decode(&encoded),
806                     "input {} suffix data len {} pad len {}",
807                     String::from_utf8(encoded).unwrap(),
808                     suffix_data_len,
809                     padding_len
810                 );
811             }
812         }
813     }
814 }
815 
816 // https://eprint.iacr.org/2022/361.pdf table 2, test 1
817 #[apply(all_engines)]
decode_malleability_test_case_3_byte_suffix_valid<E: EngineWrapper>(engine_wrapper: E)818 fn decode_malleability_test_case_3_byte_suffix_valid<E: EngineWrapper>(engine_wrapper: E) {
819     assert_eq!(
820         b"Hello".as_slice(),
821         &E::standard().decode("SGVsbG8=").unwrap()
822     );
823 }
824 
825 // https://eprint.iacr.org/2022/361.pdf table 2, test 2
826 #[apply(all_engines)]
decode_malleability_test_case_3_byte_suffix_invalid_trailing_symbol<E: EngineWrapper>( engine_wrapper: E, )827 fn decode_malleability_test_case_3_byte_suffix_invalid_trailing_symbol<E: EngineWrapper>(
828     engine_wrapper: E,
829 ) {
830     assert_eq!(
831         DecodeError::InvalidLastSymbol(6, 0x39),
832         E::standard().decode("SGVsbG9=").unwrap_err()
833     );
834 }
835 
836 // https://eprint.iacr.org/2022/361.pdf table 2, test 3
837 #[apply(all_engines)]
decode_malleability_test_case_3_byte_suffix_no_padding<E: EngineWrapper>(engine_wrapper: E)838 fn decode_malleability_test_case_3_byte_suffix_no_padding<E: EngineWrapper>(engine_wrapper: E) {
839     assert_eq!(
840         DecodeError::InvalidPadding,
841         E::standard().decode("SGVsbG9").unwrap_err()
842     );
843 }
844 
845 // https://eprint.iacr.org/2022/361.pdf table 2, test 4
846 #[apply(all_engines)]
decode_malleability_test_case_2_byte_suffix_valid_two_padding_symbols<E: EngineWrapper>( engine_wrapper: E, )847 fn decode_malleability_test_case_2_byte_suffix_valid_two_padding_symbols<E: EngineWrapper>(
848     engine_wrapper: E,
849 ) {
850     assert_eq!(
851         b"Hell".as_slice(),
852         &E::standard().decode("SGVsbA==").unwrap()
853     );
854 }
855 
856 // https://eprint.iacr.org/2022/361.pdf table 2, test 5
857 #[apply(all_engines)]
decode_malleability_test_case_2_byte_suffix_short_padding<E: EngineWrapper>(engine_wrapper: E)858 fn decode_malleability_test_case_2_byte_suffix_short_padding<E: EngineWrapper>(engine_wrapper: E) {
859     assert_eq!(
860         DecodeError::InvalidPadding,
861         E::standard().decode("SGVsbA=").unwrap_err()
862     );
863 }
864 
865 // https://eprint.iacr.org/2022/361.pdf table 2, test 6
866 #[apply(all_engines)]
decode_malleability_test_case_2_byte_suffix_no_padding<E: EngineWrapper>(engine_wrapper: E)867 fn decode_malleability_test_case_2_byte_suffix_no_padding<E: EngineWrapper>(engine_wrapper: E) {
868     assert_eq!(
869         DecodeError::InvalidPadding,
870         E::standard().decode("SGVsbA").unwrap_err()
871     );
872 }
873 
874 // https://eprint.iacr.org/2022/361.pdf table 2, test 7
875 // DecoderReader pseudo-engine gets InvalidByte at 8 (extra padding) since it decodes the first
876 // two complete quads correctly.
877 #[apply(all_engines_except_decoder_reader)]
decode_malleability_test_case_2_byte_suffix_too_much_padding<E: EngineWrapper>( engine_wrapper: E, )878 fn decode_malleability_test_case_2_byte_suffix_too_much_padding<E: EngineWrapper>(
879     engine_wrapper: E,
880 ) {
881     assert_eq!(
882         DecodeError::InvalidByte(6, PAD_BYTE),
883         E::standard().decode("SGVsbA====").unwrap_err()
884     );
885 }
886 
887 /// Requires canonical padding -> accepts 2 + 2, 3 + 1, 4 + 0 final quad configurations
888 #[apply(all_engines)]
decode_pad_mode_requires_canonical_accepts_canonical<E: EngineWrapper>(engine_wrapper: E)889 fn decode_pad_mode_requires_canonical_accepts_canonical<E: EngineWrapper>(engine_wrapper: E) {
890     assert_all_suffixes_ok(
891         E::standard_with_pad_mode(true, DecodePaddingMode::RequireCanonical),
892         vec!["/w==", "iYU=", "AAAA"],
893     );
894 }
895 
896 /// Requires canonical padding -> rejects 2 + 0-1, 3 + 0 final chunk configurations
897 #[apply(all_engines)]
decode_pad_mode_requires_canonical_rejects_non_canonical<E: EngineWrapper>(engine_wrapper: E)898 fn decode_pad_mode_requires_canonical_rejects_non_canonical<E: EngineWrapper>(engine_wrapper: E) {
899     let engine = E::standard_with_pad_mode(true, DecodePaddingMode::RequireCanonical);
900 
901     let suffixes = ["/w", "/w=", "iYU"];
902     for num_prefix_quads in 0..256 {
903         for &suffix in suffixes.iter() {
904             let mut encoded = "AAAA".repeat(num_prefix_quads);
905             encoded.push_str(suffix);
906 
907             let res = engine.decode(&encoded);
908 
909             assert_eq!(Err(DecodeError::InvalidPadding), res);
910         }
911     }
912 }
913 
914 /// Requires no padding -> accepts 2 + 0, 3 + 0, 4 + 0 final chunk configuration
915 #[apply(all_engines)]
decode_pad_mode_requires_no_padding_accepts_no_padding<E: EngineWrapper>(engine_wrapper: E)916 fn decode_pad_mode_requires_no_padding_accepts_no_padding<E: EngineWrapper>(engine_wrapper: E) {
917     assert_all_suffixes_ok(
918         E::standard_with_pad_mode(true, DecodePaddingMode::RequireNone),
919         vec!["/w", "iYU", "AAAA"],
920     );
921 }
922 
923 /// Requires no padding -> rejects 2 + 1-2, 3 + 1 final chunk configuration
924 #[apply(all_engines)]
decode_pad_mode_requires_no_padding_rejects_any_padding<E: EngineWrapper>(engine_wrapper: E)925 fn decode_pad_mode_requires_no_padding_rejects_any_padding<E: EngineWrapper>(engine_wrapper: E) {
926     let engine = E::standard_with_pad_mode(true, DecodePaddingMode::RequireNone);
927 
928     let suffixes = ["/w=", "/w==", "iYU="];
929     for num_prefix_quads in 0..256 {
930         for &suffix in suffixes.iter() {
931             let mut encoded = "AAAA".repeat(num_prefix_quads);
932             encoded.push_str(suffix);
933 
934             let res = engine.decode(&encoded);
935 
936             assert_eq!(Err(DecodeError::InvalidPadding), res);
937         }
938     }
939 }
940 
941 /// Indifferent padding accepts 2 + 0-2, 3 + 0-1, 4 + 0 final chunk configuration
942 #[apply(all_engines)]
decode_pad_mode_indifferent_padding_accepts_anything<E: EngineWrapper>(engine_wrapper: E)943 fn decode_pad_mode_indifferent_padding_accepts_anything<E: EngineWrapper>(engine_wrapper: E) {
944     assert_all_suffixes_ok(
945         E::standard_with_pad_mode(true, DecodePaddingMode::Indifferent),
946         vec!["/w", "/w=", "/w==", "iYU", "iYU=", "AAAA"],
947     );
948 }
949 
950 /// 1 trailing byte that's not padding is detected as invalid byte even though there's padding
951 /// in the middle of the input. This is essentially mandating the eager check for 1 trailing byte
952 /// to catch the \n suffix case.
953 // DecoderReader pseudo-engine can't handle DecodePaddingMode::RequireNone since it will decode
954 // a complete quad with padding in it before encountering the stray byte that makes it an invalid
955 // length
956 #[apply(all_engines_except_decoder_reader)]
decode_invalid_trailing_bytes_all_pad_modes_invalid_byte<E: EngineWrapper>(engine_wrapper: E)957 fn decode_invalid_trailing_bytes_all_pad_modes_invalid_byte<E: EngineWrapper>(engine_wrapper: E) {
958     for mode in all_pad_modes() {
959         do_invalid_trailing_byte(E::standard_with_pad_mode(true, mode), mode);
960     }
961 }
962 
963 #[apply(all_engines)]
decode_invalid_trailing_bytes_invalid_byte<E: EngineWrapper>(engine_wrapper: E)964 fn decode_invalid_trailing_bytes_invalid_byte<E: EngineWrapper>(engine_wrapper: E) {
965     // excluding no padding mode because the DecoderWrapper pseudo-engine will fail with
966     // InvalidPadding because it will decode the last complete quad with padding first
967     for mode in pad_modes_allowing_padding() {
968         do_invalid_trailing_byte(E::standard_with_pad_mode(true, mode), mode);
969     }
970 }
do_invalid_trailing_byte(engine: impl Engine, mode: DecodePaddingMode)971 fn do_invalid_trailing_byte(engine: impl Engine, mode: DecodePaddingMode) {
972     for last_byte in [b'*', b'\n'] {
973         for num_prefix_quads in 0..256 {
974             let mut s: String = "ABCD".repeat(num_prefix_quads);
975             s.push_str("Cg==");
976             let mut input = s.into_bytes();
977             input.push(last_byte);
978 
979             // The case of trailing newlines is common enough to warrant a test for a good error
980             // message.
981             assert_eq!(
982                 Err(DecodeError::InvalidByte(
983                     num_prefix_quads * 4 + 4,
984                     last_byte
985                 )),
986                 engine.decode(&input),
987                 "mode: {:?}, input: {}",
988                 mode,
989                 String::from_utf8(input).unwrap()
990             );
991         }
992     }
993 }
994 
995 /// When there's 1 trailing byte, but it's padding, it's only InvalidByte if there isn't padding
996 /// earlier.
997 #[apply(all_engines)]
decode_invalid_trailing_padding_as_invalid_byte_at_first_pad_byte<E: EngineWrapper>( engine_wrapper: E, )998 fn decode_invalid_trailing_padding_as_invalid_byte_at_first_pad_byte<E: EngineWrapper>(
999     engine_wrapper: E,
1000 ) {
1001     // excluding no padding mode because the DecoderWrapper pseudo-engine will fail with
1002     // InvalidPadding because it will decode the last complete quad with padding first
1003     for mode in pad_modes_allowing_padding() {
1004         do_invalid_trailing_padding_as_invalid_byte_at_first_padding(
1005             E::standard_with_pad_mode(true, mode),
1006             mode,
1007         );
1008     }
1009 }
1010 
1011 // DecoderReader pseudo-engine can't handle DecodePaddingMode::RequireNone since it will decode
1012 // a complete quad with padding in it before encountering the stray byte that makes it an invalid
1013 // length
1014 #[apply(all_engines_except_decoder_reader)]
decode_invalid_trailing_padding_as_invalid_byte_at_first_byte_all_modes<E: EngineWrapper>( engine_wrapper: E, )1015 fn decode_invalid_trailing_padding_as_invalid_byte_at_first_byte_all_modes<E: EngineWrapper>(
1016     engine_wrapper: E,
1017 ) {
1018     for mode in all_pad_modes() {
1019         do_invalid_trailing_padding_as_invalid_byte_at_first_padding(
1020             E::standard_with_pad_mode(true, mode),
1021             mode,
1022         );
1023     }
1024 }
do_invalid_trailing_padding_as_invalid_byte_at_first_padding( engine: impl Engine, mode: DecodePaddingMode, )1025 fn do_invalid_trailing_padding_as_invalid_byte_at_first_padding(
1026     engine: impl Engine,
1027     mode: DecodePaddingMode,
1028 ) {
1029     for num_prefix_quads in 0..256 {
1030         for (suffix, pad_offset) in [("AA===", 2), ("AAA==", 3), ("AAAA=", 4)] {
1031             let mut s: String = "ABCD".repeat(num_prefix_quads);
1032             s.push_str(suffix);
1033 
1034             assert_eq!(
1035                 // pad after `g`, not the last one
1036                 Err(DecodeError::InvalidByte(
1037                     num_prefix_quads * 4 + pad_offset,
1038                     PAD_BYTE
1039                 )),
1040                 engine.decode(&s),
1041                 "mode: {:?}, input: {}",
1042                 mode,
1043                 s
1044             );
1045         }
1046     }
1047 }
1048 
1049 #[apply(all_engines)]
decode_into_slice_fits_in_precisely_sized_slice<E: EngineWrapper>(engine_wrapper: E)1050 fn decode_into_slice_fits_in_precisely_sized_slice<E: EngineWrapper>(engine_wrapper: E) {
1051     let mut orig_data = Vec::new();
1052     let mut encoded_data = String::new();
1053     let mut decode_buf = Vec::new();
1054 
1055     let input_len_range = distributions::Uniform::new(0, 1000);
1056     let mut rng = rngs::SmallRng::from_entropy();
1057 
1058     for _ in 0..10_000 {
1059         orig_data.clear();
1060         encoded_data.clear();
1061         decode_buf.clear();
1062 
1063         let input_len = input_len_range.sample(&mut rng);
1064 
1065         for _ in 0..input_len {
1066             orig_data.push(rng.gen());
1067         }
1068 
1069         let engine = E::random(&mut rng);
1070         engine.encode_string(&orig_data, &mut encoded_data);
1071         assert_encode_sanity(&encoded_data, engine.config().encode_padding(), input_len);
1072 
1073         decode_buf.resize(input_len, 0);
1074         // decode into the non-empty buf
1075         let decode_bytes_written = engine
1076             .decode_slice_unchecked(encoded_data.as_bytes(), &mut decode_buf[..])
1077             .unwrap();
1078         assert_eq!(orig_data.len(), decode_bytes_written);
1079         assert_eq!(orig_data, decode_buf);
1080 
1081         // same for checked variant
1082         decode_buf.clear();
1083         decode_buf.resize(input_len, 0);
1084         // decode into the non-empty buf
1085         let decode_bytes_written = engine
1086             .decode_slice(encoded_data.as_bytes(), &mut decode_buf[..])
1087             .unwrap();
1088         assert_eq!(orig_data.len(), decode_bytes_written);
1089         assert_eq!(orig_data, decode_buf);
1090     }
1091 }
1092 
1093 #[apply(all_engines)]
inner_decode_reports_padding_position<E: EngineWrapper>(engine_wrapper: E)1094 fn inner_decode_reports_padding_position<E: EngineWrapper>(engine_wrapper: E) {
1095     let mut b64 = String::new();
1096     let mut decoded = Vec::new();
1097     let engine = E::standard();
1098 
1099     for pad_position in 1..10_000 {
1100         b64.clear();
1101         decoded.clear();
1102         // plenty of room for original data
1103         decoded.resize(pad_position, 0);
1104 
1105         for _ in 0..pad_position {
1106             b64.push('A');
1107         }
1108         // finish the quad with padding
1109         for _ in 0..(4 - (pad_position % 4)) {
1110             b64.push('=');
1111         }
1112 
1113         let decode_res = engine.internal_decode(
1114             b64.as_bytes(),
1115             &mut decoded[..],
1116             engine.internal_decoded_len_estimate(b64.len()),
1117         );
1118         if pad_position % 4 < 2 {
1119             // impossible padding
1120             assert_eq!(
1121                 Err(DecodeSliceError::DecodeError(DecodeError::InvalidByte(
1122                     pad_position,
1123                     PAD_BYTE
1124                 ))),
1125                 decode_res
1126             );
1127         } else {
1128             let decoded_bytes = pad_position / 4 * 3
1129                 + match pad_position % 4 {
1130                     0 => 0,
1131                     2 => 1,
1132                     3 => 2,
1133                     _ => unreachable!(),
1134                 };
1135             assert_eq!(
1136                 Ok(DecodeMetadata::new(decoded_bytes, Some(pad_position))),
1137                 decode_res
1138             );
1139         }
1140     }
1141 }
1142 
1143 #[apply(all_engines)]
decode_length_estimate_delta<E: EngineWrapper>(engine_wrapper: E)1144 fn decode_length_estimate_delta<E: EngineWrapper>(engine_wrapper: E) {
1145     for engine in [E::standard(), E::standard_unpadded()] {
1146         for &padding in &[true, false] {
1147             for orig_len in 0..1000 {
1148                 let encoded_len = encoded_len(orig_len, padding).unwrap();
1149 
1150                 let decoded_estimate = engine
1151                     .internal_decoded_len_estimate(encoded_len)
1152                     .decoded_len_estimate();
1153                 assert!(decoded_estimate >= orig_len);
1154                 assert!(
1155                     decoded_estimate - orig_len < 3,
1156                     "estimate: {}, encoded: {}, orig: {}",
1157                     decoded_estimate,
1158                     encoded_len,
1159                     orig_len
1160                 );
1161             }
1162         }
1163     }
1164 }
1165 
1166 #[apply(all_engines)]
estimate_via_u128_inflation<E: EngineWrapper>(engine_wrapper: E)1167 fn estimate_via_u128_inflation<E: EngineWrapper>(engine_wrapper: E) {
1168     // cover both ends of usize
1169     (0..1000)
1170         .chain(usize::MAX - 1000..=usize::MAX)
1171         .for_each(|encoded_len| {
1172             // inflate to 128 bit type to be able to safely use the easy formulas
1173             let len_128 = encoded_len as u128;
1174 
1175             let estimate = E::standard()
1176                 .internal_decoded_len_estimate(encoded_len)
1177                 .decoded_len_estimate();
1178 
1179             // This check is a little too strict: it requires using the (len + 3) / 4 * 3 formula
1180             // or equivalent, but until other engines come along that use a different formula
1181             // requiring that we think more carefully about what the allowable criteria are, this
1182             // will do.
1183             assert_eq!(
1184                 ((len_128 + 3) / 4 * 3) as usize,
1185                 estimate,
1186                 "enc len {}",
1187                 encoded_len
1188             );
1189         })
1190 }
1191 
1192 #[apply(all_engines)]
decode_slice_checked_fails_gracefully_at_all_output_lengths<E: EngineWrapper>( engine_wrapper: E, )1193 fn decode_slice_checked_fails_gracefully_at_all_output_lengths<E: EngineWrapper>(
1194     engine_wrapper: E,
1195 ) {
1196     let mut rng = seeded_rng();
1197     for original_len in 0..1000 {
1198         let mut original = vec![0; original_len];
1199         rng.fill(&mut original[..]);
1200 
1201         for mode in all_pad_modes() {
1202             let engine = E::standard_with_pad_mode(
1203                 match mode {
1204                     DecodePaddingMode::Indifferent | DecodePaddingMode::RequireCanonical => true,
1205                     DecodePaddingMode::RequireNone => false,
1206                 },
1207                 mode,
1208             );
1209 
1210             let encoded = engine.encode(&original);
1211             let mut decode_buf = Vec::with_capacity(original_len);
1212             for decode_buf_len in 0..original_len {
1213                 decode_buf.resize(decode_buf_len, 0);
1214                 assert_eq!(
1215                     DecodeSliceError::OutputSliceTooSmall,
1216                     engine
1217                         .decode_slice(&encoded, &mut decode_buf[..])
1218                         .unwrap_err(),
1219                     "original len: {}, encoded len: {}, buf len: {}, mode: {:?}",
1220                     original_len,
1221                     encoded.len(),
1222                     decode_buf_len,
1223                     mode
1224                 );
1225                 // internal method works the same
1226                 assert_eq!(
1227                     DecodeSliceError::OutputSliceTooSmall,
1228                     engine
1229                         .internal_decode(
1230                             encoded.as_bytes(),
1231                             &mut decode_buf[..],
1232                             engine.internal_decoded_len_estimate(encoded.len())
1233                         )
1234                         .unwrap_err()
1235                 );
1236             }
1237 
1238             decode_buf.resize(original_len, 0);
1239             rng.fill(&mut decode_buf[..]);
1240             assert_eq!(
1241                 original_len,
1242                 engine.decode_slice(&encoded, &mut decode_buf[..]).unwrap()
1243             );
1244             assert_eq!(original, decode_buf);
1245         }
1246     }
1247 }
1248 
1249 /// Returns a tuple of the original data length, the encoded data length (just data), and the length including padding.
1250 ///
1251 /// Vecs provided should be empty.
generate_random_encoded_data<E: Engine, R: rand::Rng, D: distributions::Distribution<usize>>( engine: &E, orig_data: &mut Vec<u8>, encode_buf: &mut Vec<u8>, rng: &mut R, length_distribution: &D, ) -> (usize, usize, usize)1252 fn generate_random_encoded_data<E: Engine, R: rand::Rng, D: distributions::Distribution<usize>>(
1253     engine: &E,
1254     orig_data: &mut Vec<u8>,
1255     encode_buf: &mut Vec<u8>,
1256     rng: &mut R,
1257     length_distribution: &D,
1258 ) -> (usize, usize, usize) {
1259     let padding: bool = engine.config().encode_padding();
1260 
1261     let orig_len = fill_rand(orig_data, rng, length_distribution);
1262     let expected_encoded_len = encoded_len(orig_len, padding).unwrap();
1263     encode_buf.resize(expected_encoded_len, 0);
1264 
1265     let base_encoded_len = engine.internal_encode(&orig_data[..], &mut encode_buf[..]);
1266 
1267     let enc_len_with_padding = if padding {
1268         base_encoded_len + add_padding(base_encoded_len, &mut encode_buf[base_encoded_len..])
1269     } else {
1270         base_encoded_len
1271     };
1272 
1273     assert_eq!(expected_encoded_len, enc_len_with_padding);
1274 
1275     (orig_len, base_encoded_len, enc_len_with_padding)
1276 }
1277 
1278 // fill to a random length
fill_rand<R: rand::Rng, D: distributions::Distribution<usize>>( vec: &mut Vec<u8>, rng: &mut R, length_distribution: &D, ) -> usize1279 fn fill_rand<R: rand::Rng, D: distributions::Distribution<usize>>(
1280     vec: &mut Vec<u8>,
1281     rng: &mut R,
1282     length_distribution: &D,
1283 ) -> usize {
1284     let len = length_distribution.sample(rng);
1285     for _ in 0..len {
1286         vec.push(rng.gen());
1287     }
1288 
1289     len
1290 }
1291 
fill_rand_len<R: rand::Rng>(vec: &mut Vec<u8>, rng: &mut R, len: usize)1292 fn fill_rand_len<R: rand::Rng>(vec: &mut Vec<u8>, rng: &mut R, len: usize) {
1293     for _ in 0..len {
1294         vec.push(rng.gen());
1295     }
1296 }
1297 
/// Cuts `input_with_prefix` back to its first `prefix_len` bytes, appends `data`, and returns
/// the resulting contents.
///
/// A `prefix_len` beyond the current length leaves the existing contents untouched.
fn prefixed_data<'i>(input_with_prefix: &'i mut String, prefix_len: usize, data: &str) -> &'i str {
    input_with_prefix.truncate(prefix_len);
    *input_with_prefix += data;
    input_with_prefix
}
1303 
1304 /// A wrapper to make using engines in rstest fixtures easier.
1305 /// The functions don't need to be instance methods, but rstest does seem
1306 /// to want an instance, so instances are passed to test functions and then ignored.
1307 trait EngineWrapper {
1308     type Engine: Engine;
1309 
1310     /// Return an engine configured for RFC standard base64
standard() -> Self::Engine1311     fn standard() -> Self::Engine;
1312 
1313     /// Return an engine configured for RFC standard base64, except with no padding appended on
1314     /// encode, and required no padding on decode.
standard_unpadded() -> Self::Engine1315     fn standard_unpadded() -> Self::Engine;
1316 
1317     /// Return an engine configured for RFC standard alphabet with the provided encode and decode
1318     /// pad settings
standard_with_pad_mode(encode_pad: bool, decode_pad_mode: DecodePaddingMode) -> Self::Engine1319     fn standard_with_pad_mode(encode_pad: bool, decode_pad_mode: DecodePaddingMode)
1320         -> Self::Engine;
1321 
1322     /// Return an engine configured for RFC standard base64 that allows invalid trailing bits
standard_allow_trailing_bits() -> Self::Engine1323     fn standard_allow_trailing_bits() -> Self::Engine;
1324 
1325     /// Return an engine configured with a randomized alphabet and config
random<R: rand::Rng>(rng: &mut R) -> Self::Engine1326     fn random<R: rand::Rng>(rng: &mut R) -> Self::Engine;
1327 
1328     /// Return an engine configured with the specified alphabet and randomized config
random_alphabet<R: rand::Rng>(rng: &mut R, alphabet: &Alphabet) -> Self::Engine1329     fn random_alphabet<R: rand::Rng>(rng: &mut R, alphabet: &Alphabet) -> Self::Engine;
1330 }
1331 
1332 struct GeneralPurposeWrapper {}
1333 
1334 impl EngineWrapper for GeneralPurposeWrapper {
1335     type Engine = general_purpose::GeneralPurpose;
1336 
standard() -> Self::Engine1337     fn standard() -> Self::Engine {
1338         general_purpose::GeneralPurpose::new(&STANDARD, general_purpose::PAD)
1339     }
1340 
standard_unpadded() -> Self::Engine1341     fn standard_unpadded() -> Self::Engine {
1342         general_purpose::GeneralPurpose::new(&STANDARD, general_purpose::NO_PAD)
1343     }
1344 
standard_with_pad_mode( encode_pad: bool, decode_pad_mode: DecodePaddingMode, ) -> Self::Engine1345     fn standard_with_pad_mode(
1346         encode_pad: bool,
1347         decode_pad_mode: DecodePaddingMode,
1348     ) -> Self::Engine {
1349         general_purpose::GeneralPurpose::new(
1350             &STANDARD,
1351             general_purpose::GeneralPurposeConfig::new()
1352                 .with_encode_padding(encode_pad)
1353                 .with_decode_padding_mode(decode_pad_mode),
1354         )
1355     }
1356 
standard_allow_trailing_bits() -> Self::Engine1357     fn standard_allow_trailing_bits() -> Self::Engine {
1358         general_purpose::GeneralPurpose::new(
1359             &STANDARD,
1360             general_purpose::GeneralPurposeConfig::new().with_decode_allow_trailing_bits(true),
1361         )
1362     }
1363 
random<R: rand::Rng>(rng: &mut R) -> Self::Engine1364     fn random<R: rand::Rng>(rng: &mut R) -> Self::Engine {
1365         let alphabet = random_alphabet(rng);
1366 
1367         Self::random_alphabet(rng, alphabet)
1368     }
1369 
random_alphabet<R: rand::Rng>(rng: &mut R, alphabet: &Alphabet) -> Self::Engine1370     fn random_alphabet<R: rand::Rng>(rng: &mut R, alphabet: &Alphabet) -> Self::Engine {
1371         general_purpose::GeneralPurpose::new(alphabet, random_config(rng))
1372     }
1373 }
1374 
1375 struct NaiveWrapper {}
1376 
1377 impl EngineWrapper for NaiveWrapper {
1378     type Engine = naive::Naive;
1379 
standard() -> Self::Engine1380     fn standard() -> Self::Engine {
1381         naive::Naive::new(
1382             &STANDARD,
1383             naive::NaiveConfig {
1384                 encode_padding: true,
1385                 decode_allow_trailing_bits: false,
1386                 decode_padding_mode: DecodePaddingMode::RequireCanonical,
1387             },
1388         )
1389     }
1390 
standard_unpadded() -> Self::Engine1391     fn standard_unpadded() -> Self::Engine {
1392         naive::Naive::new(
1393             &STANDARD,
1394             naive::NaiveConfig {
1395                 encode_padding: false,
1396                 decode_allow_trailing_bits: false,
1397                 decode_padding_mode: DecodePaddingMode::RequireNone,
1398             },
1399         )
1400     }
1401 
standard_with_pad_mode( encode_pad: bool, decode_pad_mode: DecodePaddingMode, ) -> Self::Engine1402     fn standard_with_pad_mode(
1403         encode_pad: bool,
1404         decode_pad_mode: DecodePaddingMode,
1405     ) -> Self::Engine {
1406         naive::Naive::new(
1407             &STANDARD,
1408             naive::NaiveConfig {
1409                 encode_padding: encode_pad,
1410                 decode_allow_trailing_bits: false,
1411                 decode_padding_mode: decode_pad_mode,
1412             },
1413         )
1414     }
1415 
standard_allow_trailing_bits() -> Self::Engine1416     fn standard_allow_trailing_bits() -> Self::Engine {
1417         naive::Naive::new(
1418             &STANDARD,
1419             naive::NaiveConfig {
1420                 encode_padding: true,
1421                 decode_allow_trailing_bits: true,
1422                 decode_padding_mode: DecodePaddingMode::RequireCanonical,
1423             },
1424         )
1425     }
1426 
random<R: rand::Rng>(rng: &mut R) -> Self::Engine1427     fn random<R: rand::Rng>(rng: &mut R) -> Self::Engine {
1428         let alphabet = random_alphabet(rng);
1429 
1430         Self::random_alphabet(rng, alphabet)
1431     }
1432 
random_alphabet<R: rand::Rng>(rng: &mut R, alphabet: &Alphabet) -> Self::Engine1433     fn random_alphabet<R: rand::Rng>(rng: &mut R, alphabet: &Alphabet) -> Self::Engine {
1434         let mode = rng.gen();
1435 
1436         let config = naive::NaiveConfig {
1437             encode_padding: match mode {
1438                 DecodePaddingMode::Indifferent => rng.gen(),
1439                 DecodePaddingMode::RequireCanonical => true,
1440                 DecodePaddingMode::RequireNone => false,
1441             },
1442             decode_allow_trailing_bits: rng.gen(),
1443             decode_padding_mode: mode,
1444         };
1445 
1446         naive::Naive::new(alphabet, config)
1447     }
1448 }
1449 
1450 /// A pseudo-Engine that routes all decoding through [DecoderReader]
1451 struct DecoderReaderEngine<E: Engine> {
1452     engine: E,
1453 }
1454 
1455 impl<E: Engine> From<E> for DecoderReaderEngine<E> {
from(value: E) -> Self1456     fn from(value: E) -> Self {
1457         Self { engine: value }
1458     }
1459 }
1460 
1461 impl<E: Engine> Engine for DecoderReaderEngine<E> {
1462     type Config = E::Config;
1463     type DecodeEstimate = E::DecodeEstimate;
1464 
internal_encode(&self, input: &[u8], output: &mut [u8]) -> usize1465     fn internal_encode(&self, input: &[u8], output: &mut [u8]) -> usize {
1466         self.engine.internal_encode(input, output)
1467     }
1468 
internal_decoded_len_estimate(&self, input_len: usize) -> Self::DecodeEstimate1469     fn internal_decoded_len_estimate(&self, input_len: usize) -> Self::DecodeEstimate {
1470         self.engine.internal_decoded_len_estimate(input_len)
1471     }
1472 
internal_decode( &self, input: &[u8], output: &mut [u8], decode_estimate: Self::DecodeEstimate, ) -> Result<DecodeMetadata, DecodeSliceError>1473     fn internal_decode(
1474         &self,
1475         input: &[u8],
1476         output: &mut [u8],
1477         decode_estimate: Self::DecodeEstimate,
1478     ) -> Result<DecodeMetadata, DecodeSliceError> {
1479         let mut reader = DecoderReader::new(input, &self.engine);
1480         let mut buf = vec![0; input.len()];
1481         // to avoid effects like not detecting invalid length due to progressively growing
1482         // the output buffer in read_to_end etc, read into a big enough buffer in one go
1483         // to make behavior more consistent with normal engines
1484         let _ = reader
1485             .read(&mut buf)
1486             .and_then(|len| {
1487                 buf.truncate(len);
1488                 // make sure we got everything
1489                 reader.read_to_end(&mut buf)
1490             })
1491             .map_err(|io_error| {
1492                 *io_error
1493                     .into_inner()
1494                     .and_then(|inner| inner.downcast::<DecodeError>().ok())
1495                     .unwrap()
1496             })?;
1497         if output.len() < buf.len() {
1498             return Err(DecodeSliceError::OutputSliceTooSmall);
1499         }
1500         output[..buf.len()].copy_from_slice(&buf);
1501         Ok(DecodeMetadata::new(
1502             buf.len(),
1503             input
1504                 .iter()
1505                 .enumerate()
1506                 .filter(|(_offset, byte)| **byte == PAD_BYTE)
1507                 .map(|(offset, _byte)| offset)
1508                 .next(),
1509         ))
1510     }
1511 
config(&self) -> &Self::Config1512     fn config(&self) -> &Self::Config {
1513         self.engine.config()
1514     }
1515 }
1516 
/// `EngineWrapper` implementor whose engines are `DecoderReaderEngine`s wrapping
/// `general_purpose::GeneralPurpose`.
struct DecoderReaderEngineWrapper {}
1518 
1519 impl EngineWrapper for DecoderReaderEngineWrapper {
1520     type Engine = DecoderReaderEngine<general_purpose::GeneralPurpose>;
1521 
standard() -> Self::Engine1522     fn standard() -> Self::Engine {
1523         GeneralPurposeWrapper::standard().into()
1524     }
1525 
standard_unpadded() -> Self::Engine1526     fn standard_unpadded() -> Self::Engine {
1527         GeneralPurposeWrapper::standard_unpadded().into()
1528     }
1529 
standard_with_pad_mode( encode_pad: bool, decode_pad_mode: DecodePaddingMode, ) -> Self::Engine1530     fn standard_with_pad_mode(
1531         encode_pad: bool,
1532         decode_pad_mode: DecodePaddingMode,
1533     ) -> Self::Engine {
1534         GeneralPurposeWrapper::standard_with_pad_mode(encode_pad, decode_pad_mode).into()
1535     }
1536 
standard_allow_trailing_bits() -> Self::Engine1537     fn standard_allow_trailing_bits() -> Self::Engine {
1538         GeneralPurposeWrapper::standard_allow_trailing_bits().into()
1539     }
1540 
random<R: rand::Rng>(rng: &mut R) -> Self::Engine1541     fn random<R: rand::Rng>(rng: &mut R) -> Self::Engine {
1542         GeneralPurposeWrapper::random(rng).into()
1543     }
1544 
random_alphabet<R: rand::Rng>(rng: &mut R, alphabet: &Alphabet) -> Self::Engine1545     fn random_alphabet<R: rand::Rng>(rng: &mut R, alphabet: &Alphabet) -> Self::Engine {
1546         GeneralPurposeWrapper::random_alphabet(rng, alphabet).into()
1547     }
1548 }
1549 
seeded_rng() -> impl rand::Rng1550 fn seeded_rng() -> impl rand::Rng {
1551     rngs::SmallRng::from_entropy()
1552 }
1553 
all_pad_modes() -> Vec<DecodePaddingMode>1554 fn all_pad_modes() -> Vec<DecodePaddingMode> {
1555     vec![
1556         DecodePaddingMode::Indifferent,
1557         DecodePaddingMode::RequireCanonical,
1558         DecodePaddingMode::RequireNone,
1559     ]
1560 }
1561 
pad_modes_allowing_padding() -> Vec<DecodePaddingMode>1562 fn pad_modes_allowing_padding() -> Vec<DecodePaddingMode> {
1563     vec![
1564         DecodePaddingMode::Indifferent,
1565         DecodePaddingMode::RequireCanonical,
1566     ]
1567 }
1568 
assert_all_suffixes_ok<E: Engine>(engine: E, suffixes: Vec<&str>)1569 fn assert_all_suffixes_ok<E: Engine>(engine: E, suffixes: Vec<&str>) {
1570     for num_prefix_quads in 0..256 {
1571         for &suffix in suffixes.iter() {
1572             let mut encoded = "AAAA".repeat(num_prefix_quads);
1573             encoded.push_str(suffix);
1574 
1575             let res = &engine.decode(&encoded);
1576             assert!(res.is_ok());
1577         }
1578     }
1579 }
1580