1 use std::io::{Cursor, Write};
2 use std::{cmp, io, str};
3 
4 use rand::Rng;
5 
6 use crate::{
7     alphabet::{STANDARD, URL_SAFE},
8     engine::{
9         general_purpose::{GeneralPurpose, NO_PAD, PAD},
10         Engine,
11     },
12     tests::random_engine,
13 };
14 
15 use super::EncoderWriter;
16 
/// Engine built from the `URL_SAFE` alphabet with the `PAD` configuration.
const URL_SAFE_ENGINE: GeneralPurpose = GeneralPurpose::new(&URL_SAFE, PAD);
/// Engine built from the `STANDARD` alphabet with the `NO_PAD` configuration.
const NO_PAD_ENGINE: GeneralPurpose = GeneralPurpose::new(&STANDARD, NO_PAD);
19 
/// A single 3-byte write produces the same bytes as one-shot encoding of the same input.
#[test]
fn encode_three_bytes() {
    let mut sink = Cursor::new(Vec::<u8>::new());

    let mut writer = EncoderWriter::new(&mut sink, &URL_SAFE_ENGINE);
    let written = writer.write(b"abc").unwrap();
    assert_eq!(written, 3);
    // The writer mutably borrows `sink`, so it must be gone before the output is inspected.
    drop(writer);

    let expected = URL_SAFE_ENGINE.encode("abc");
    assert_eq!(sink.get_ref().as_slice(), expected.as_bytes());
}
31 
/// Nine bytes written as a 6-byte call followed by a 3-byte call match one-shot encoding.
#[test]
fn encode_nine_bytes_two_writes() {
    let mut sink = Cursor::new(Vec::<u8>::new());

    let mut writer = EncoderWriter::new(&mut sink, &URL_SAFE_ENGINE);
    let first = writer.write(b"abcdef").unwrap();
    assert_eq!(first, 6);
    let second = writer.write(b"ghi").unwrap();
    assert_eq!(second, 3);
    // Release the borrow on `sink` before reading what was written to it.
    drop(writer);

    let expected = URL_SAFE_ENGINE.encode("abcdefghi");
    assert_eq!(sink.get_ref().as_slice(), expected.as_bytes());
}
48 
/// A 1-byte write followed by a 2-byte write matches one-shot encoding of the three bytes.
#[test]
fn encode_one_then_two_bytes() {
    let mut sink = Cursor::new(Vec::<u8>::new());

    let mut writer = EncoderWriter::new(&mut sink, &URL_SAFE_ENGINE);
    let first = writer.write(b"a").unwrap();
    assert_eq!(first, 1);
    let second = writer.write(b"bc").unwrap();
    assert_eq!(second, 2);
    // Release the borrow on `sink` before reading what was written to it.
    drop(writer);

    let expected = URL_SAFE_ENGINE.encode("abc");
    assert_eq!(sink.get_ref().as_slice(), expected.as_bytes());
}
62 
/// A 1-byte write followed by a 5-byte write matches one-shot encoding of the six bytes.
#[test]
fn encode_one_then_five_bytes() {
    let mut sink = Cursor::new(Vec::<u8>::new());

    let mut writer = EncoderWriter::new(&mut sink, &URL_SAFE_ENGINE);
    let first = writer.write(b"a").unwrap();
    assert_eq!(first, 1);
    let second = writer.write(b"bcdef").unwrap();
    assert_eq!(second, 5);
    // Release the borrow on `sink` before reading what was written to it.
    drop(writer);

    let expected = URL_SAFE_ENGINE.encode("abcdef");
    assert_eq!(sink.get_ref().as_slice(), expected.as_bytes());
}
79 
/// Writes of 1, 2 and 3 bytes in sequence match one-shot encoding of the six bytes.
#[test]
fn encode_1_2_3_bytes() {
    let mut sink = Cursor::new(Vec::<u8>::new());

    let mut writer = EncoderWriter::new(&mut sink, &URL_SAFE_ENGINE);
    let first = writer.write(b"a").unwrap();
    assert_eq!(first, 1);
    let second = writer.write(b"bc").unwrap();
    assert_eq!(second, 2);
    let third = writer.write(b"def").unwrap();
    assert_eq!(third, 3);
    // Release the borrow on `sink` before reading what was written to it.
    drop(writer);

    let expected = URL_SAFE_ENGINE.encode("abcdef");
    assert_eq!(sink.get_ref().as_slice(), expected.as_bytes());
}
98 
/// Four bytes (not a multiple of three) written with `write_all` and then flushed match
/// one-shot encoding with the padding engine.
#[test]
fn encode_with_padding() {
    let mut sink = Cursor::new(Vec::<u8>::new());

    let mut writer = EncoderWriter::new(&mut sink, &URL_SAFE_ENGINE);
    writer.write_all(b"abcd").unwrap();
    writer.flush().unwrap();
    // Release the borrow on `sink` before reading what was written to it.
    drop(writer);

    let expected = URL_SAFE_ENGINE.encode("abcd");
    assert_eq!(sink.get_ref().as_slice(), expected.as_bytes());
}
111 
/// Seven bytes spread over four writes, then flushed, match one-shot encoding with the
/// padding engine.
#[test]
fn encode_with_padding_multiple_writes() {
    let mut sink = Cursor::new(Vec::<u8>::new());

    let mut writer = EncoderWriter::new(&mut sink, &URL_SAFE_ENGINE);
    assert_eq!(1, writer.write(b"a").unwrap());
    assert_eq!(2, writer.write(b"bc").unwrap());
    assert_eq!(3, writer.write(b"def").unwrap());
    assert_eq!(1, writer.write(b"g").unwrap());
    writer.flush().unwrap();
    // Release the borrow on `sink` before reading what was written to it.
    drop(writer);

    let expected = URL_SAFE_ENGINE.encode("abcdefg");
    assert_eq!(sink.get_ref().as_slice(), expected.as_bytes());
}
130 
/// A trailing byte left over after two full chunks is emitted once `finish()` is called.
#[test]
fn finish_writes_extra_byte() {
    let mut sink = Cursor::new(Vec::<u8>::new());

    let mut writer = EncoderWriter::new(&mut sink, &URL_SAFE_ENGINE);
    assert_eq!(6, writer.write(b"abcdef").unwrap());
    // this single byte cannot form a chunk on its own, so it is held back as "extra"
    assert_eq!(1, writer.write(b"g").unwrap());
    // finish() is what encodes the held-back byte
    let _ = writer.finish().unwrap();
    drop(writer);

    let expected = URL_SAFE_ENGINE.encode("abcdefg");
    assert_eq!(sink.get_ref().as_slice(), expected.as_bytes());
}
150 
/// Two bytes (less than one chunk) are accepted by `write` and encoded by `finish()`,
/// giving three unpadded output characters.
#[test]
fn write_partial_chunk_encodes_partial_chunk() {
    let mut sink = Cursor::new(Vec::<u8>::new());

    let mut writer = EncoderWriter::new(&mut sink, &NO_PAD_ENGINE);
    // accepted, but no encoding happens at this point
    assert_eq!(2, writer.write(b"ab").unwrap());
    // the partial chunk is encoded here
    let _ = writer.finish().unwrap();
    drop(writer);

    let encoded = sink.get_ref();
    let expected = NO_PAD_ENGINE.encode("ab");
    assert_eq!(encoded.as_slice(), expected.as_bytes());
    assert_eq!(3, encoded.len());
}
165 
/// Exactly one 3-byte chunk is fully consumed and yields four output characters.
#[test]
fn write_1_chunk_encodes_complete_chunk() {
    let mut sink = Cursor::new(Vec::<u8>::new());

    let mut writer = EncoderWriter::new(&mut sink, &NO_PAD_ENGINE);
    assert_eq!(3, writer.write(b"abc").unwrap());
    let _ = writer.finish().unwrap();
    drop(writer);

    let encoded = sink.get_ref();
    let expected = NO_PAD_ENGINE.encode("abc");
    assert_eq!(encoded.as_slice(), expected.as_bytes());
    assert_eq!(4, encoded.len());
}
178 
/// When offered one full chunk plus one extra byte, `write` consumes only the full chunk.
#[test]
fn write_1_chunk_and_partial_encodes_only_complete_chunk() {
    let mut sink = Cursor::new(Vec::<u8>::new());

    let mut writer = EncoderWriter::new(&mut sink, &NO_PAD_ENGINE);
    // the trailing "d" does not make up a full chunk, so it is left unconsumed
    assert_eq!(3, writer.write(b"abcd").unwrap());
    let _ = writer.finish().unwrap();
    drop(writer);

    let encoded = sink.get_ref();
    let expected = NO_PAD_ENGINE.encode("abc");
    assert_eq!(encoded.as_slice(), expected.as_bytes());
    assert_eq!(4, encoded.len());
}
192 
/// Two partial writes (1 byte, then 2 bytes) that together form one chunk are both fully
/// consumed and encode to four characters.
#[test]
fn write_2_partials_to_exactly_complete_chunk_encodes_complete_chunk() {
    let mut sink = Cursor::new(Vec::<u8>::new());

    let mut writer = EncoderWriter::new(&mut sink, &NO_PAD_ENGINE);
    assert_eq!(1, writer.write(b"a").unwrap());
    assert_eq!(2, writer.write(b"bc").unwrap());
    let _ = writer.finish().unwrap();
    drop(writer);

    let encoded = sink.get_ref();
    let expected = NO_PAD_ENGINE.encode("abc");
    assert_eq!(encoded.as_slice(), expected.as_bytes());
    assert_eq!(4, encoded.len());
}
206 
207 #[test]
write_partial_then_enough_to_complete_chunk_but_not_complete_another_chunk_encodes_complete_chunk_without_consuming_remaining( )208 fn write_partial_then_enough_to_complete_chunk_but_not_complete_another_chunk_encodes_complete_chunk_without_consuming_remaining(
209 ) {
210     let mut c = Cursor::new(Vec::new());
211     {
212         let mut enc = EncoderWriter::new(&mut c, &NO_PAD_ENGINE);
213 
214         assert_eq!(1, enc.write(b"a").unwrap());
215         // doesn't consume "d"
216         assert_eq!(2, enc.write(b"bcd").unwrap());
217         let _ = enc.finish().unwrap();
218     }
219     assert_eq!(&c.get_ref()[..], NO_PAD_ENGINE.encode("abc").as_bytes());
220     assert_eq!(4, c.get_ref().len());
221 }
222 
/// After a 1-byte partial write, a 5-byte write that completes the pending chunk and
/// supplies one more full chunk is consumed entirely.
#[test]
fn write_partial_then_enough_to_complete_chunk_and_another_chunk_encodes_complete_chunks() {
    let mut sink = Cursor::new(Vec::<u8>::new());

    let mut writer = EncoderWriter::new(&mut sink, &NO_PAD_ENGINE);
    assert_eq!(1, writer.write(b"a").unwrap());
    // finishes the pending chunk and provides a second whole chunk
    assert_eq!(5, writer.write(b"bcdef").unwrap());
    let _ = writer.finish().unwrap();
    drop(writer);

    let encoded = sink.get_ref();
    let expected = NO_PAD_ENGINE.encode("abcdef");
    assert_eq!(encoded.as_slice(), expected.as_bytes());
    assert_eq!(8, encoded.len());
}
237 
238 #[test]
write_partial_then_enough_to_complete_chunk_and_another_chunk_and_another_partial_chunk_encodes_only_complete_chunks( )239 fn write_partial_then_enough_to_complete_chunk_and_another_chunk_and_another_partial_chunk_encodes_only_complete_chunks(
240 ) {
241     let mut c = Cursor::new(Vec::new());
242     {
243         let mut enc = EncoderWriter::new(&mut c, &NO_PAD_ENGINE);
244 
245         assert_eq!(1, enc.write(b"a").unwrap());
246         // completes partial chunk, and another chunk, with one more partial chunk that's not
247         // consumed
248         assert_eq!(5, enc.write(b"bcdefe").unwrap());
249         let _ = enc.finish().unwrap();
250     }
251     assert_eq!(&c.get_ref()[..], NO_PAD_ENGINE.encode("abcdef").as_bytes());
252     assert_eq!(8, c.get_ref().len());
253 }
254 
/// Dropping the writer without calling `finish()` still emits the pending partial chunk.
#[test]
fn drop_calls_finish_for_you() {
    let mut sink = Cursor::new(Vec::<u8>::new());

    let mut writer = EncoderWriter::new(&mut sink, &NO_PAD_ENGINE);
    assert_eq!(1, writer.write(b"a").unwrap());
    // no explicit finish(): the drop is what this test exercises
    drop(writer);

    let encoded = sink.get_ref();
    let expected = NO_PAD_ENGINE.encode("a");
    assert_eq!(encoded.as_slice(), expected.as_bytes());
    assert_eq!(2, encoded.len());
}
265 
/// Splits a fixed-size random input at every offset, feeds the two parts to the stream
/// encoder as separate `write_all` calls, and checks the result against one-shot encoding
/// with the same randomly chosen engine.
#[test]
fn every_possible_split_of_input() {
    let mut rng = rand::thread_rng();
    let mut orig_data = Vec::<u8>::new();
    let mut stream_encoded = Vec::<u8>::new();
    let mut normal_encoded = String::new();

    let size = 5_000;

    // Inclusive upper bound: `i == size` is the split where the whole input goes into the first
    // write and the second write is empty, the mirror image of the `i == 0` case.
    for i in 0..=size {
        orig_data.clear();
        stream_encoded.clear();
        normal_encoded.clear();

        for _ in 0..size {
            orig_data.push(rng.gen());
        }

        let engine = random_engine(&mut rng);
        engine.encode_string(&orig_data, &mut normal_encoded);

        {
            let mut stream_encoder = EncoderWriter::new(&mut stream_encoded, &engine);
            // Write the first i bytes, then the rest
            stream_encoder.write_all(&orig_data[0..i]).unwrap();
            stream_encoder.write_all(&orig_data[i..]).unwrap();
            // leaving this scope drops the encoder, releasing its borrow of `stream_encoded`
        }

        assert_eq!(normal_encoded, str::from_utf8(&stream_encoded).unwrap());
    }
}
297 
/// Runs the randomized stream-vs-one-shot comparison with individual writes of up to twice
/// the encoder's buffer size.
#[test]
fn encode_random_config_matches_normal_encode_reasonable_input_len() {
    // with a cap of 2 * buf size, roughly half of the writes will fill a whole buffer
    let max_input_len = 2 * super::encoder::BUF_SIZE;
    do_encode_random_config_matches_normal_encode(max_input_len);
}
303 
/// Runs the randomized stream-vs-one-shot comparison with individual writes shorter than
/// ten bytes.
#[test]
fn encode_random_config_matches_normal_encode_tiny_input_len() {
    let max_input_len = 10;
    do_encode_random_config_matches_normal_encode(max_input_len);
}
308 
/// Streams random data through an encoder whose delegate frequently fails with
/// `Interrupted`, retrying every interrupted call, and checks the output against one-shot
/// encoding with the same engine.
#[test]
fn retrying_writes_that_error_with_interrupted_works() {
    let mut rng = rand::thread_rng();
    let mut plain = Vec::<u8>::new();
    let mut streamed = Vec::<u8>::new();
    let mut expected = String::new();

    for _ in 0..1_000 {
        plain.clear();
        streamed.clear();
        expected.clear();

        let total: usize = rng.gen_range(100..20_000);
        plain.extend((0..total).map(|_| rng.gen::<u8>()));

        // reference result from the one-shot encoder
        let engine = random_engine(&mut rng);
        engine.encode_string(&plain, &mut expected);

        // same data through the stream encoder, with a flaky delegate
        {
            let mut interrupt_rng = rand::thread_rng();
            let mut flaky_sink = InterruptingWriter {
                w: &mut streamed,
                rng: &mut interrupt_rng,
                fraction: 0.8,
            };
            let mut encoder = EncoderWriter::new(&mut flaky_sink, &engine);

            let mut offset = 0;
            while offset < total {
                // keep inputs short so that `extra` is used a lot; that is the state that
                // has to be rolled back when an error occurs
                let chunk_len: usize = cmp::min(rng.gen_range(0..10), total - offset);
                let chunk_end = offset + chunk_len;

                retry_interrupted_write_all(&mut encoder, &plain[offset..chunk_end]).unwrap();

                offset = chunk_end;
            }

            // keep calling finish() until it gets through without being interrupted
            loop {
                match encoder.finish() {
                    Ok(_) => break,
                    Err(e) if e.kind() == io::ErrorKind::Interrupted => continue,
                    Err(e) => panic!("{:?}", e), // bail
                }
            }

            assert_eq!(total, offset);
        }

        assert_eq!(expected, str::from_utf8(&streamed).unwrap());
    }
}
372 
/// Streams random data through an encoder whose delegate both interrupts and accepts only
/// part of what it is given, advancing by whatever each `write` reports, and checks the
/// output against one-shot encoding with the same engine.
#[test]
fn writes_that_only_write_part_of_input_and_sometimes_interrupt_produce_correct_encoded_data() {
    let mut rng = rand::thread_rng();
    let mut plain = Vec::<u8>::new();
    let mut streamed = Vec::<u8>::new();
    let mut expected = String::new();

    for _ in 0..1_000 {
        plain.clear();
        streamed.clear();
        expected.clear();

        let total: usize = rng.gen_range(100..20_000);
        plain.extend((0..total).map(|_| rng.gen::<u8>()));

        // reference result from the one-shot encoder
        let engine = random_engine(&mut rng);
        engine.encode_string(&plain, &mut expected);

        // same data through the stream encoder, with a partial/interrupting delegate
        {
            let mut partial_rng = rand::thread_rng();
            let mut partial_sink = PartialInterruptingWriter {
                w: &mut streamed,
                rng: &mut partial_rng,
                full_input_fraction: 0.1,
                no_interrupt_fraction: 0.1,
            };
            let mut encoder = EncoderWriter::new(&mut partial_sink, &engine);

            let mut offset = 0;
            while offset < total {
                // inputs are capped at medium length to hit the retry logic more often
                let chunk_len: usize = cmp::min(rng.gen_range(0..100), total - offset);

                match encoder.write(&plain[offset..offset + chunk_len]) {
                    Ok(accepted) => offset += accepted,
                    // an interrupt consumed nothing: go around again from the same offset
                    Err(e) if e.kind() == io::ErrorKind::Interrupted => {}
                    Err(_) => panic!("should not see other errors"),
                }
            }

            let _ = encoder.finish().unwrap();

            assert_eq!(total, offset);
        }

        assert_eq!(expected, str::from_utf8(&streamed).unwrap());
    }
}
433 
/// Keeps calling `w.write` on the unwritten tail of `buf` until every byte has been accepted.
///
/// `Interrupted` errors are swallowed and the write is attempted again from the same
/// position; any other error is returned to the caller immediately.
fn retry_interrupted_write_all<W: Write>(w: &mut W, buf: &[u8]) -> io::Result<()> {
    let total = buf.len();
    let mut offset = 0;

    loop {
        if offset >= total {
            return Ok(());
        }

        match w.write(&buf[offset..]) {
            Ok(accepted) => offset += accepted,
            // nothing was consumed; retry from the same offset
            Err(err) if err.kind() == io::ErrorKind::Interrupted => {}
            Err(err) => return Err(err),
        }
    }
}
452 
do_encode_random_config_matches_normal_encode(max_input_len: usize)453 fn do_encode_random_config_matches_normal_encode(max_input_len: usize) {
454     let mut rng = rand::thread_rng();
455     let mut orig_data = Vec::<u8>::new();
456     let mut stream_encoded = Vec::<u8>::new();
457     let mut normal_encoded = String::new();
458 
459     for _ in 0..1_000 {
460         orig_data.clear();
461         stream_encoded.clear();
462         normal_encoded.clear();
463 
464         let orig_len: usize = rng.gen_range(100..20_000);
465         for _ in 0..orig_len {
466             orig_data.push(rng.gen());
467         }
468 
469         // encode the normal way
470         let engine = random_engine(&mut rng);
471         engine.encode_string(&orig_data, &mut normal_encoded);
472 
473         // encode via the stream encoder
474         {
475             let mut stream_encoder = EncoderWriter::new(&mut stream_encoded, &engine);
476             let mut bytes_consumed = 0;
477             while bytes_consumed < orig_len {
478                 let input_len: usize =
479                     cmp::min(rng.gen_range(0..max_input_len), orig_len - bytes_consumed);
480 
481                 // write a little bit of the data
482                 stream_encoder
483                     .write_all(&orig_data[bytes_consumed..bytes_consumed + input_len])
484                     .unwrap();
485 
486                 bytes_consumed += input_len;
487             }
488 
489             let _ = stream_encoder.finish().unwrap();
490 
491             assert_eq!(orig_len, bytes_consumed);
492         }
493 
494         assert_eq!(normal_encoded, str::from_utf8(&stream_encoded).unwrap());
495     }
496 }
497 
498 /// A `Write` implementation that returns Interrupted some fraction of the time, randomly.
499 struct InterruptingWriter<'a, W: 'a + Write, R: 'a + Rng> {
500     w: &'a mut W,
501     rng: &'a mut R,
502     /// In [0, 1]. If a random number in [0, 1] is  `<= threshold`, `Write` methods will return
503     /// an `Interrupted` error
504     fraction: f64,
505 }
506 
507 impl<'a, W: Write, R: Rng> Write for InterruptingWriter<'a, W, R> {
write(&mut self, buf: &[u8]) -> io::Result<usize>508     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
509         if self.rng.gen_range(0.0..1.0) <= self.fraction {
510             return Err(io::Error::new(io::ErrorKind::Interrupted, "interrupted"));
511         }
512 
513         self.w.write(buf)
514     }
515 
flush(&mut self) -> io::Result<()>516     fn flush(&mut self) -> io::Result<()> {
517         if self.rng.gen_range(0.0..1.0) <= self.fraction {
518             return Err(io::Error::new(io::ErrorKind::Interrupted, "interrupted"));
519         }
520 
521         self.w.flush()
522     }
523 }
524 
525 /// A `Write` implementation that sometimes will only write part of its input.
526 struct PartialInterruptingWriter<'a, W: 'a + Write, R: 'a + Rng> {
527     w: &'a mut W,
528     rng: &'a mut R,
529     /// In [0, 1]. If a random number in [0, 1] is  `<= threshold`, `write()` will write all its
530     /// input. Otherwise, it will write a random substring
531     full_input_fraction: f64,
532     no_interrupt_fraction: f64,
533 }
534 
535 impl<'a, W: Write, R: Rng> Write for PartialInterruptingWriter<'a, W, R> {
write(&mut self, buf: &[u8]) -> io::Result<usize>536     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
537         if self.rng.gen_range(0.0..1.0) > self.no_interrupt_fraction {
538             return Err(io::Error::new(io::ErrorKind::Interrupted, "interrupted"));
539         }
540 
541         if self.rng.gen_range(0.0..1.0) <= self.full_input_fraction || buf.is_empty() {
542             // pass through the buf untouched
543             self.w.write(buf)
544         } else {
545             // only use a prefix of it
546             self.w
547                 .write(&buf[0..(self.rng.gen_range(0..(buf.len() - 1)))])
548         }
549     }
550 
flush(&mut self) -> io::Result<()>551     fn flush(&mut self) -> io::Result<()> {
552         self.w.flush()
553     }
554 }
555