1 #![warn(rust_2018_idioms)]
2 
3 use tokio::io::AsyncWrite;
4 use tokio_test::{assert_ready, task};
5 use tokio_util::codec::{Encoder, FramedWrite};
6 
7 use bytes::{BufMut, BytesMut};
8 use futures_sink::Sink;
9 use std::collections::VecDeque;
10 use std::io::{self, Write};
11 use std::pin::Pin;
12 use std::task::Poll::{Pending, Ready};
13 use std::task::{Context, Poll};
14 
/// Builds a `Mock` whose `calls` queue holds the given results in the order
/// written. Every entry, including the last, must be followed by a comma.
macro_rules! mock {
    ($($call:expr,)*) => {
        Mock {
            calls: VecDeque::from(vec![$($call),*]),
        }
    };
}
22 
/// Mutably re-borrows the named local as a `Pin<&mut _>` via `Pin::new`, so
/// the binding itself stays usable after the pinned call returns.
macro_rules! pin {
    ($target:ident) => {
        Pin::new(&mut $target)
    };
}
28 
29 struct U32Encoder;
30 
31 impl Encoder<u32> for U32Encoder {
32     type Error = io::Error;
33 
encode(&mut self, item: u32, dst: &mut BytesMut) -> io::Result<()>34     fn encode(&mut self, item: u32, dst: &mut BytesMut) -> io::Result<()> {
35         // Reserve space
36         dst.reserve(4);
37         dst.put_u32(item);
38         Ok(())
39     }
40 }
41 
42 struct U64Encoder;
43 
44 impl Encoder<u64> for U64Encoder {
45     type Error = io::Error;
46 
encode(&mut self, item: u64, dst: &mut BytesMut) -> io::Result<()>47     fn encode(&mut self, item: u64, dst: &mut BytesMut) -> io::Result<()> {
48         // Reserve space
49         dst.reserve(8);
50         dst.put_u64(item);
51         Ok(())
52     }
53 }
54 
/// Three `u32` frames queued back to back must reach the transport as a
/// single 12-byte write once the sink is flushed.
#[test]
fn write_multi_frame_in_packet() {
    let mut task = task::spawn(());
    // The only scripted write: frames 0, 1 and 2 laid end to end.
    let mock = mock! {
        Ok(b"\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x02".to_vec()),
    };
    let mut framed = FramedWrite::new(mock, U32Encoder);

    task.enter(|cx, _| {
        for frame in 0..3u32 {
            assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok());
            assert!(pin!(framed).start_send(frame).is_ok());
        }

        // Still buffered: the scripted write has not been consumed yet.
        assert_eq!(1, framed.get_ref().calls.len());

        // Flushing hands the buffered frames to the mock.
        assert!(assert_ready!(pin!(framed).poll_flush(cx)).is_ok());

        assert_eq!(0, framed.get_ref().calls.len());
    });
}
80 
/// A frame buffered by the `u32` encoder must survive `map_encoder`, and be
/// written together with a frame produced by the replacement `u64` encoder.
#[test]
fn write_multi_frame_after_codec_changed() {
    let mut task = task::spawn(());
    // Expected wire image: 0x04 as four bytes, then 0x08 as eight bytes.
    let mock = mock! {
        Ok(b"\x00\x00\x00\x04\x00\x00\x00\x00\x00\x00\x00\x08".to_vec()),
    };
    let mut framed = FramedWrite::new(mock, U32Encoder);

    task.enter(|cx, _| {
        // First frame goes through the 32-bit encoder.
        assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok());
        assert!(pin!(framed).start_send(0x04_u32).is_ok());

        // Swap the encoder, then queue a second frame.
        let mut widened = framed.map_encoder(|_| U64Encoder);
        assert!(assert_ready!(pin!(widened).poll_ready(cx)).is_ok());
        assert!(pin!(widened).start_send(0x08_u64).is_ok());

        // Still buffered: the scripted write has not been consumed yet.
        assert_eq!(1, widened.get_ref().calls.len());

        // Flushing hands both frames to the mock.
        assert!(assert_ready!(pin!(widened).poll_flush(cx)).is_ok());

        assert_eq!(0, widened.get_ref().calls.len());
    });
}
106 
/// Fills the sink up to its backpressure boundary, checks that `poll_ready`
/// reports `Pending` while the transport refuses the write, and then drains
/// everything once the transport accepts data again.
#[test]
fn write_hits_backpressure() {
    const ITER: usize = 2 * 1024;

    // Wire image of the frames `0..=ITER * 2`, four big-endian bytes apiece.
    let wire: Vec<u8> = (0..=ITER * 2)
        .flat_map(|i| (i as u32).to_be_bytes())
        .collect();

    // The first write attempt is refused with `WouldBlock` ...
    let mut mock = mock! {
        Err(io::Error::new(io::ErrorKind::WouldBlock, "not ready")),
    };
    // ... after which the wire image is accepted in 2kb chunks.
    mock.calls.extend(
        wire.chunks(ITER)
            .map(|chunk| Ok::<_, io::Error>(chunk.to_vec())),
    );
    // 1 'wouldblock', 8 * 2KB buffers, 1 4-byte buffer
    assert_eq!(mock.calls.len(), 10);

    let mut task = task::spawn(());
    let mut framed = FramedWrite::new(mock, U32Encoder);
    framed.set_backpressure_boundary(ITER * 8);
    task.enter(|cx, _| {
        // Queue 16KB, which fills the FramedWrite buffer up to the boundary.
        for seq in 0..ITER * 2 {
            assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok());
            assert!(pin!(framed).start_send(seq as u32).is_ok());
        }

        // This poll_ready forces a flush; the mock pops its front entry,
        // which is the scripted `WouldBlock`.
        assert!(pin!(framed).poll_ready(cx).is_pending());

        // Polling again forces another flush, which now succeeds and
        // drains the whole 16KB buffer.
        assert!(assert_ready!(pin!(framed).poll_ready(cx)).is_ok());

        // One more frame; it matches the final entry expected by the mock.
        let last = (ITER * 2) as u32;
        assert!(pin!(framed).start_send(last).is_ok());

        // Flush the rest of the buffer.
        assert!(assert_ready!(pin!(framed).poll_flush(cx)).is_ok());

        // Every scripted call must have been consumed.
        assert_eq!(0, framed.get_ref().calls.len());
    })
}
167 
// ===== Mock =====
169 
/// Scripted transport: `calls` lists, in order, the outcome of every write
/// the test expects to see.
struct Mock {
    calls: VecDeque<io::Result<Vec<u8>>>,
}

impl Write for Mock {
    /// Consumes the next scripted outcome. A scripted `Ok(data)` must be a
    /// prefix of `src` and reports `data.len()` bytes as written; a scripted
    /// error is returned unchanged.
    ///
    /// # Panics
    ///
    /// Panics if the script is exhausted, or if `src` does not start with
    /// the scripted bytes.
    fn write(&mut self, src: &[u8]) -> io::Result<usize> {
        let data = self
            .calls
            .pop_front()
            .unwrap_or_else(|| panic!("unexpected write; {src:?}"))?;
        assert!(src.len() >= data.len());
        assert_eq!(&data[..], &src[..data.len()]);
        Ok(data.len())
    }

    /// Flushing does nothing and always succeeds.
    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}
191 
192 impl AsyncWrite for Mock {
poll_write( self: Pin<&mut Self>, _cx: &mut Context<'_>, buf: &[u8], ) -> Poll<Result<usize, io::Error>>193     fn poll_write(
194         self: Pin<&mut Self>,
195         _cx: &mut Context<'_>,
196         buf: &[u8],
197     ) -> Poll<Result<usize, io::Error>> {
198         match Pin::get_mut(self).write(buf) {
199             Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Pending,
200             other => Ready(other),
201         }
202     }
poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>>203     fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
204         match Pin::get_mut(self).flush() {
205             Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Pending,
206             other => Ready(other),
207         }
208     }
poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>>209     fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
210         unimplemented!()
211     }
212 }
213