1 use std::io::{Cursor, Write};
2 use std::{cmp, io, str};
3
4 use rand::Rng;
5
6 use crate::{
7 alphabet::{STANDARD, URL_SAFE},
8 engine::{
9 general_purpose::{GeneralPurpose, NO_PAD, PAD},
10 Engine,
11 },
12 tests::random_engine,
13 };
14
15 use super::EncoderWriter;
16
/// Engine built from the `URL_SAFE` alphabet with the `PAD` configuration; used by the tests
/// that compare streamed output against one-shot encoding.
const URL_SAFE_ENGINE: GeneralPurpose = GeneralPurpose::new(&URL_SAFE, PAD);
/// Engine built from the `STANDARD` alphabet with the `NO_PAD` configuration; used by the tests
/// that assert exact encoded lengths.
const NO_PAD_ENGINE: GeneralPurpose = GeneralPurpose::new(&STANDARD, NO_PAD);
19
/// A single three-byte write must produce the same output as one-shot encoding of `"abc"`.
#[test]
fn encode_three_bytes() {
    let mut sink = Cursor::new(Vec::new());
    let mut writer = EncoderWriter::new(&mut sink, &URL_SAFE_ENGINE);

    assert_eq!(writer.write(b"abc").unwrap(), 3);

    // Dropping the writer ends its mutable borrow of `sink` before the output is inspected.
    drop(writer);

    let expected = URL_SAFE_ENGINE.encode("abc");
    assert_eq!(&sink.get_ref()[..], expected.as_bytes());
}
31
/// Nine bytes delivered as a six-byte write followed by a three-byte write must encode exactly
/// like the whole input encoded at once.
#[test]
fn encode_nine_bytes_two_writes() {
    let mut sink = Cursor::new(Vec::new());
    let mut writer = EncoderWriter::new(&mut sink, &URL_SAFE_ENGINE);

    assert_eq!(writer.write(b"abcdef").unwrap(), 6);
    assert_eq!(writer.write(b"ghi").unwrap(), 3);

    // Release the borrow of `sink` so the collected bytes can be read.
    drop(writer);

    let expected = URL_SAFE_ENGINE.encode("abcdefghi");
    assert_eq!(&sink.get_ref()[..], expected.as_bytes());
}
48
/// A one-byte write followed by a two-byte write together form `"abc"` and must encode as such.
#[test]
fn encode_one_then_two_bytes() {
    let mut sink = Cursor::new(Vec::new());
    let mut writer = EncoderWriter::new(&mut sink, &URL_SAFE_ENGINE);

    assert_eq!(writer.write(b"a").unwrap(), 1);
    assert_eq!(writer.write(b"bc").unwrap(), 2);

    // Release the borrow of `sink` so the collected bytes can be read.
    drop(writer);

    let expected = URL_SAFE_ENGINE.encode("abc");
    assert_eq!(&sink.get_ref()[..], expected.as_bytes());
}
62
/// A one-byte write followed by a five-byte write together form `"abcdef"` and must encode as
/// such.
#[test]
fn encode_one_then_five_bytes() {
    let mut sink = Cursor::new(Vec::new());
    let mut writer = EncoderWriter::new(&mut sink, &URL_SAFE_ENGINE);

    assert_eq!(writer.write(b"a").unwrap(), 1);
    assert_eq!(writer.write(b"bcdef").unwrap(), 5);

    // Release the borrow of `sink` so the collected bytes can be read.
    drop(writer);

    let expected = URL_SAFE_ENGINE.encode("abcdef");
    assert_eq!(&sink.get_ref()[..], expected.as_bytes());
}
79
/// Writes of one, two and three bytes in sequence must encode like `"abcdef"` written at once.
#[test]
fn encode_1_2_3_bytes() {
    let mut sink = Cursor::new(Vec::new());
    let mut writer = EncoderWriter::new(&mut sink, &URL_SAFE_ENGINE);

    assert_eq!(writer.write(b"a").unwrap(), 1);
    assert_eq!(writer.write(b"bc").unwrap(), 2);
    assert_eq!(writer.write(b"def").unwrap(), 3);

    // Release the borrow of `sink` so the collected bytes can be read.
    drop(writer);

    let expected = URL_SAFE_ENGINE.encode("abcdef");
    assert_eq!(&sink.get_ref()[..], expected.as_bytes());
}
98
/// Four input bytes do not fill a whole number of three-byte chunks; the streamed output must
/// still match one-shot encoding of `"abcd"` with the padding engine.
#[test]
fn encode_with_padding() {
    let mut sink = Cursor::new(Vec::new());
    let mut writer = EncoderWriter::new(&mut sink, &URL_SAFE_ENGINE);

    writer.write_all(b"abcd").unwrap();
    writer.flush().unwrap();

    // Release the borrow of `sink` so the collected bytes can be read.
    drop(writer);

    let expected = URL_SAFE_ENGINE.encode("abcd");
    assert_eq!(&sink.get_ref()[..], expected.as_bytes());
}
111
/// Seven bytes delivered over four differently sized writes, followed by a flush, must match
/// one-shot encoding of `"abcdefg"` with the padding engine.
#[test]
fn encode_with_padding_multiple_writes() {
    let mut sink = Cursor::new(Vec::new());
    let mut writer = EncoderWriter::new(&mut sink, &URL_SAFE_ENGINE);

    assert_eq!(1, writer.write(b"a").unwrap());
    assert_eq!(2, writer.write(b"bc").unwrap());
    assert_eq!(3, writer.write(b"def").unwrap());
    assert_eq!(1, writer.write(b"g").unwrap());
    writer.flush().unwrap();

    // Release the borrow of `sink` so the collected bytes can be read.
    drop(writer);

    let expected = URL_SAFE_ENGINE.encode("abcdefg");
    assert_eq!(&sink.get_ref()[..], expected.as_bytes());
}
130
/// After two full chunks and one leftover byte, `finish` must emit the encoding of that leftover
/// byte so the total output matches one-shot encoding of `"abcdefg"`.
#[test]
fn finish_writes_extra_byte() {
    let mut sink = Cursor::new(Vec::new());
    let mut writer = EncoderWriter::new(&mut sink, &URL_SAFE_ENGINE);

    assert_eq!(6, writer.write(b"abcdef").unwrap());

    // This byte cannot complete a chunk on its own, so it is held back as "extra" input.
    assert_eq!(1, writer.write(b"g").unwrap());

    // The single trailing byte turns into two encoded characters here.
    let _ = writer.finish().unwrap();

    // Release the borrow of `sink` so the collected bytes can be read.
    drop(writer);

    let expected = URL_SAFE_ENGINE.encode("abcdefg");
    assert_eq!(&sink.get_ref()[..], expected.as_bytes());
}
150
/// Two bytes are accepted by `write` and are encoded (as three unpadded characters) once
/// `finish` is called.
#[test]
fn write_partial_chunk_encodes_partial_chunk() {
    let mut sink = Cursor::new(Vec::new());
    let mut writer = EncoderWriter::new(&mut sink, &NO_PAD_ENGINE);

    // Accepted, but no output is produced for an incomplete chunk yet.
    assert_eq!(2, writer.write(b"ab").unwrap());
    // The partial chunk is encoded at this point.
    let _ = writer.finish().unwrap();

    // Release the borrow of `sink` so the collected bytes can be read.
    drop(writer);

    let encoded = sink.get_ref();
    assert_eq!(&encoded[..], NO_PAD_ENGINE.encode("ab").as_bytes());
    assert_eq!(3, encoded.len());
}
165
/// Exactly one three-byte chunk is fully consumed and yields four encoded characters.
#[test]
fn write_1_chunk_encodes_complete_chunk() {
    let mut sink = Cursor::new(Vec::new());
    let mut writer = EncoderWriter::new(&mut sink, &NO_PAD_ENGINE);

    assert_eq!(3, writer.write(b"abc").unwrap());
    let _ = writer.finish().unwrap();

    // Release the borrow of `sink` so the collected bytes can be read.
    drop(writer);

    let encoded = sink.get_ref();
    assert_eq!(&encoded[..], NO_PAD_ENGINE.encode("abc").as_bytes());
    assert_eq!(4, encoded.len());
}
178
/// A four-byte write only consumes the leading complete chunk; because the test never
/// re-submits the unconsumed byte, the output is the encoding of `"abc"` alone.
#[test]
fn write_1_chunk_and_partial_encodes_only_complete_chunk() {
    let mut sink = Cursor::new(Vec::new());
    let mut writer = EncoderWriter::new(&mut sink, &NO_PAD_ENGINE);

    // The trailing "d" is left unconsumed because it does not make up a full chunk.
    assert_eq!(3, writer.write(b"abcd").unwrap());
    let _ = writer.finish().unwrap();

    // Release the borrow of `sink` so the collected bytes can be read.
    drop(writer);

    let encoded = sink.get_ref();
    assert_eq!(&encoded[..], NO_PAD_ENGINE.encode("abc").as_bytes());
    assert_eq!(4, encoded.len());
}
192
/// Two partial writes (one byte, then two) that add up to exactly one chunk produce the
/// four-character encoding of `"abc"`.
#[test]
fn write_2_partials_to_exactly_complete_chunk_encodes_complete_chunk() {
    let mut sink = Cursor::new(Vec::new());
    let mut writer = EncoderWriter::new(&mut sink, &NO_PAD_ENGINE);

    assert_eq!(1, writer.write(b"a").unwrap());
    assert_eq!(2, writer.write(b"bc").unwrap());
    let _ = writer.finish().unwrap();

    // Release the borrow of `sink` so the collected bytes can be read.
    drop(writer);

    let encoded = sink.get_ref();
    assert_eq!(&encoded[..], NO_PAD_ENGINE.encode("abc").as_bytes());
    assert_eq!(4, encoded.len());
}
206
207 #[test]
write_partial_then_enough_to_complete_chunk_but_not_complete_another_chunk_encodes_complete_chunk_without_consuming_remaining( )208 fn write_partial_then_enough_to_complete_chunk_but_not_complete_another_chunk_encodes_complete_chunk_without_consuming_remaining(
209 ) {
210 let mut c = Cursor::new(Vec::new());
211 {
212 let mut enc = EncoderWriter::new(&mut c, &NO_PAD_ENGINE);
213
214 assert_eq!(1, enc.write(b"a").unwrap());
215 // doesn't consume "d"
216 assert_eq!(2, enc.write(b"bcd").unwrap());
217 let _ = enc.finish().unwrap();
218 }
219 assert_eq!(&c.get_ref()[..], NO_PAD_ENGINE.encode("abc").as_bytes());
220 assert_eq!(4, c.get_ref().len());
221 }
222
223 #[test]
write_partial_then_enough_to_complete_chunk_and_another_chunk_encodes_complete_chunks()224 fn write_partial_then_enough_to_complete_chunk_and_another_chunk_encodes_complete_chunks() {
225 let mut c = Cursor::new(Vec::new());
226 {
227 let mut enc = EncoderWriter::new(&mut c, &NO_PAD_ENGINE);
228
229 assert_eq!(1, enc.write(b"a").unwrap());
230 // completes partial chunk, and another chunk
231 assert_eq!(5, enc.write(b"bcdef").unwrap());
232 let _ = enc.finish().unwrap();
233 }
234 assert_eq!(&c.get_ref()[..], NO_PAD_ENGINE.encode("abcdef").as_bytes());
235 assert_eq!(8, c.get_ref().len());
236 }
237
238 #[test]
write_partial_then_enough_to_complete_chunk_and_another_chunk_and_another_partial_chunk_encodes_only_complete_chunks( )239 fn write_partial_then_enough_to_complete_chunk_and_another_chunk_and_another_partial_chunk_encodes_only_complete_chunks(
240 ) {
241 let mut c = Cursor::new(Vec::new());
242 {
243 let mut enc = EncoderWriter::new(&mut c, &NO_PAD_ENGINE);
244
245 assert_eq!(1, enc.write(b"a").unwrap());
246 // completes partial chunk, and another chunk, with one more partial chunk that's not
247 // consumed
248 assert_eq!(5, enc.write(b"bcdefe").unwrap());
249 let _ = enc.finish().unwrap();
250 }
251 assert_eq!(&c.get_ref()[..], NO_PAD_ENGINE.encode("abcdef").as_bytes());
252 assert_eq!(8, c.get_ref().len());
253 }
254
/// Letting the writer go out of scope without calling `finish` must still emit the buffered
/// byte's encoding (two unpadded characters for `"a"`).
#[test]
fn drop_calls_finish_for_you() {
    let mut sink = Cursor::new(Vec::new());
    {
        // Intentionally no `finish()`: leaving this scope drops the writer.
        let mut writer = EncoderWriter::new(&mut sink, &NO_PAD_ENGINE);
        assert_eq!(1, writer.write(b"a").unwrap());
    }

    let encoded = sink.get_ref();
    assert_eq!(&encoded[..], NO_PAD_ENGINE.encode("a").as_bytes());
    assert_eq!(2, encoded.len());
}
265
/// Streams a random payload as two consecutive `write_all` calls, once for every split point of
/// the payload, and checks each result against one-shot encoding with the same random engine.
#[test]
fn every_possible_split_of_input() {
    let mut rng = rand::thread_rng();
    let mut orig_data = Vec::<u8>::new();
    let mut stream_encoded = Vec::<u8>::new();
    let mut normal_encoded = String::new();

    let size = 5_000;

    // The bound is inclusive so that `i == size` (the whole payload first, then an empty write)
    // is exercised too; an exclusive bound would skip that split.
    for i in 0..=size {
        orig_data.clear();
        stream_encoded.clear();
        normal_encoded.clear();

        // Fresh random payload of `size` bytes for every split point.
        for _ in 0..size {
            orig_data.push(rng.gen());
        }

        let engine = random_engine(&mut rng);
        engine.encode_string(&orig_data, &mut normal_encoded);

        {
            let mut stream_encoder = EncoderWriter::new(&mut stream_encoded, &engine);
            // Write the first i bytes, then the rest
            stream_encoder.write_all(&orig_data[0..i]).unwrap();
            stream_encoder.write_all(&orig_data[i..]).unwrap();
            // Leaving the scope drops the encoder, releasing `stream_encoded` for the comparison.
        }

        assert_eq!(normal_encoded, str::from_utf8(&stream_encoded).unwrap());
    }
}
297
/// Runs the randomized stream-vs-one-shot comparison with write sizes of up to twice the
/// encoder's buffer size.
#[test]
fn encode_random_config_matches_normal_encode_reasonable_input_len() {
    // With a ceiling of 2 * buffer size, about half of the writes should fill a whole buffer.
    let max_input_len = super::encoder::BUF_SIZE * 2;
    do_encode_random_config_matches_normal_encode(max_input_len);
}
303
/// Runs the randomized stream-vs-one-shot comparison with very small writes (fewer than ten
/// bytes each).
#[test]
fn encode_random_config_matches_normal_encode_tiny_input_len() {
    let max_input_len = 10;
    do_encode_random_config_matches_normal_encode(max_input_len);
}
308
/// Streams random data through an `InterruptingWriter` that fails most calls with
/// `Interrupted`, retrying every interrupted operation, and checks that the final output equals
/// one-shot encoding.
#[test]
fn retrying_writes_that_error_with_interrupted_works() {
    let mut rng = rand::thread_rng();
    let mut orig_data = Vec::<u8>::new();
    let mut stream_encoded = Vec::<u8>::new();
    let mut normal_encoded = String::new();

    for _ in 0..1_000 {
        orig_data.clear();
        stream_encoded.clear();
        normal_encoded.clear();

        let orig_len: usize = rng.gen_range(100..20_000);
        orig_data.extend((0..orig_len).map(|_| rng.gen::<u8>()));

        // Reference result: one-shot encoding with a randomly configured engine.
        let engine = random_engine(&mut rng);
        engine.encode_string(&orig_data, &mut normal_encoded);

        // Same data through the streaming encoder on top of an unreliable sink.
        {
            let mut interrupt_rng = rand::thread_rng();
            let mut interrupting_writer = InterruptingWriter {
                w: &mut stream_encoded,
                rng: &mut interrupt_rng,
                fraction: 0.8,
            };
            let mut stream_encoder = EncoderWriter::new(&mut interrupting_writer, &engine);

            let mut bytes_consumed = 0;
            while bytes_consumed < orig_len {
                // Short inputs keep the encoder's `extra` buffer busy, which is the state that
                // has to be rolled back when an error occurs.
                let input_len: usize = cmp::min(rng.gen_range(0..10), orig_len - bytes_consumed);
                let chunk_end = bytes_consumed + input_len;

                retry_interrupted_write_all(
                    &mut stream_encoder,
                    &orig_data[bytes_consumed..chunk_end],
                )
                .unwrap();

                bytes_consumed = chunk_end;
            }

            // `finish` can be interrupted as well; keep calling it until it succeeds.
            loop {
                match stream_encoder.finish() {
                    Ok(_) => break,
                    Err(ref e) if e.kind() == io::ErrorKind::Interrupted => continue,
                    Err(e) => panic!("{:?}", e), // bail
                }
            }

            assert_eq!(orig_len, bytes_consumed);
        }

        assert_eq!(normal_encoded, str::from_utf8(&stream_encoded).unwrap());
    }
}
372
/// Streams random data through a `PartialInterruptingWriter` (which interrupts most calls and
/// usually accepts only part of its input), re-submitting unconsumed input, and checks that the
/// final output equals one-shot encoding.
#[test]
fn writes_that_only_write_part_of_input_and_sometimes_interrupt_produce_correct_encoded_data() {
    let mut rng = rand::thread_rng();
    let mut orig_data = Vec::<u8>::new();
    let mut stream_encoded = Vec::<u8>::new();
    let mut normal_encoded = String::new();

    for _ in 0..1_000 {
        orig_data.clear();
        stream_encoded.clear();
        normal_encoded.clear();

        let orig_len: usize = rng.gen_range(100..20_000);
        orig_data.extend((0..orig_len).map(|_| rng.gen::<u8>()));

        // Reference result: one-shot encoding with a randomly configured engine.
        let engine = random_engine(&mut rng);
        engine.encode_string(&orig_data, &mut normal_encoded);

        // Same data through the streaming encoder on top of an unreliable sink.
        {
            let mut partial_rng = rand::thread_rng();
            let mut partial_writer = PartialInterruptingWriter {
                w: &mut stream_encoded,
                rng: &mut partial_rng,
                full_input_fraction: 0.1,
                no_interrupt_fraction: 0.1,
            };
            let mut stream_encoder = EncoderWriter::new(&mut partial_writer, &engine);

            let mut bytes_consumed = 0;
            while bytes_consumed < orig_len {
                // Inputs of at most medium length make the retry paths run more often.
                let input_len: usize = cmp::min(rng.gen_range(0..100), orig_len - bytes_consumed);
                let chunk = &orig_data[bytes_consumed..bytes_consumed + input_len];

                match stream_encoder.write(chunk) {
                    // Only advance by what was actually consumed; the rest is offered again.
                    Ok(len) => bytes_consumed += len,
                    // An interrupted write is simply attempted again with a new random length.
                    Err(ref e) if e.kind() == io::ErrorKind::Interrupted => continue,
                    Err(_) => panic!("should not see other errors"),
                }
            }

            let _ = stream_encoder.finish().unwrap();

            assert_eq!(orig_len, bytes_consumed);
        }

        assert_eq!(normal_encoded, str::from_utf8(&stream_encoded).unwrap());
    }
}
433
/// Writes all of `buf` to `w`, re-issuing `write` after every `Interrupted` error and after
/// every short write.
///
/// Unlike `Write::write_all`, a successful write of zero bytes is not an error here; the write
/// is simply attempted again. Any error other than `Interrupted` is returned to the caller
/// immediately.
fn retry_interrupted_write_all<W: Write>(w: &mut W, buf: &[u8]) -> io::Result<()> {
    let mut written = 0;

    while written < buf.len() {
        match w.write(&buf[written..]) {
            Ok(n) => written += n,
            // Interrupted: nothing was consumed, so offer the same tail again.
            Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
            Err(e) => return Err(e),
        }
    }

    Ok(())
}
452
do_encode_random_config_matches_normal_encode(max_input_len: usize)453 fn do_encode_random_config_matches_normal_encode(max_input_len: usize) {
454 let mut rng = rand::thread_rng();
455 let mut orig_data = Vec::<u8>::new();
456 let mut stream_encoded = Vec::<u8>::new();
457 let mut normal_encoded = String::new();
458
459 for _ in 0..1_000 {
460 orig_data.clear();
461 stream_encoded.clear();
462 normal_encoded.clear();
463
464 let orig_len: usize = rng.gen_range(100..20_000);
465 for _ in 0..orig_len {
466 orig_data.push(rng.gen());
467 }
468
469 // encode the normal way
470 let engine = random_engine(&mut rng);
471 engine.encode_string(&orig_data, &mut normal_encoded);
472
473 // encode via the stream encoder
474 {
475 let mut stream_encoder = EncoderWriter::new(&mut stream_encoded, &engine);
476 let mut bytes_consumed = 0;
477 while bytes_consumed < orig_len {
478 let input_len: usize =
479 cmp::min(rng.gen_range(0..max_input_len), orig_len - bytes_consumed);
480
481 // write a little bit of the data
482 stream_encoder
483 .write_all(&orig_data[bytes_consumed..bytes_consumed + input_len])
484 .unwrap();
485
486 bytes_consumed += input_len;
487 }
488
489 let _ = stream_encoder.finish().unwrap();
490
491 assert_eq!(orig_len, bytes_consumed);
492 }
493
494 assert_eq!(normal_encoded, str::from_utf8(&stream_encoded).unwrap());
495 }
496 }
497
498 /// A `Write` implementation that returns Interrupted some fraction of the time, randomly.
499 struct InterruptingWriter<'a, W: 'a + Write, R: 'a + Rng> {
500 w: &'a mut W,
501 rng: &'a mut R,
502 /// In [0, 1]. If a random number in [0, 1] is `<= threshold`, `Write` methods will return
503 /// an `Interrupted` error
504 fraction: f64,
505 }
506
507 impl<'a, W: Write, R: Rng> Write for InterruptingWriter<'a, W, R> {
write(&mut self, buf: &[u8]) -> io::Result<usize>508 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
509 if self.rng.gen_range(0.0..1.0) <= self.fraction {
510 return Err(io::Error::new(io::ErrorKind::Interrupted, "interrupted"));
511 }
512
513 self.w.write(buf)
514 }
515
flush(&mut self) -> io::Result<()>516 fn flush(&mut self) -> io::Result<()> {
517 if self.rng.gen_range(0.0..1.0) <= self.fraction {
518 return Err(io::Error::new(io::ErrorKind::Interrupted, "interrupted"));
519 }
520
521 self.w.flush()
522 }
523 }
524
525 /// A `Write` implementation that sometimes will only write part of its input.
526 struct PartialInterruptingWriter<'a, W: 'a + Write, R: 'a + Rng> {
527 w: &'a mut W,
528 rng: &'a mut R,
529 /// In [0, 1]. If a random number in [0, 1] is `<= threshold`, `write()` will write all its
530 /// input. Otherwise, it will write a random substring
531 full_input_fraction: f64,
532 no_interrupt_fraction: f64,
533 }
534
535 impl<'a, W: Write, R: Rng> Write for PartialInterruptingWriter<'a, W, R> {
write(&mut self, buf: &[u8]) -> io::Result<usize>536 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
537 if self.rng.gen_range(0.0..1.0) > self.no_interrupt_fraction {
538 return Err(io::Error::new(io::ErrorKind::Interrupted, "interrupted"));
539 }
540
541 if self.rng.gen_range(0.0..1.0) <= self.full_input_fraction || buf.is_empty() {
542 // pass through the buf untouched
543 self.w.write(buf)
544 } else {
545 // only use a prefix of it
546 self.w
547 .write(&buf[0..(self.rng.gen_range(0..(buf.len() - 1)))])
548 }
549 }
550
flush(&mut self) -> io::Result<()>551 fn flush(&mut self) -> io::Result<()> {
552 self.w.flush()
553 }
554 }
555