1 #![warn(rust_2018_idioms)]
2 #![cfg(feature = "full")]
3 
4 use tokio::io::{AsyncWrite, AsyncWriteExt};
5 use tokio_test::{assert_err, assert_ok};
6 
7 use bytes::{Buf, Bytes, BytesMut};
8 use std::cmp;
9 use std::io;
10 use std::pin::Pin;
11 use std::task::{Context, Poll};
12 
13 #[tokio::test]
write_all_buf()14 async fn write_all_buf() {
15     struct Wr {
16         buf: BytesMut,
17         cnt: usize,
18     }
19 
20     impl AsyncWrite for Wr {
21         fn poll_write(
22             mut self: Pin<&mut Self>,
23             _cx: &mut Context<'_>,
24             buf: &[u8],
25         ) -> Poll<io::Result<usize>> {
26             let n = cmp::min(4, buf.len());
27             dbg!(buf);
28             let buf = &buf[0..n];
29 
30             self.cnt += 1;
31             self.buf.extend(buf);
32             Ok(buf.len()).into()
33         }
34 
35         fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
36             Ok(()).into()
37         }
38 
39         fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
40             Ok(()).into()
41         }
42     }
43 
44     let mut wr = Wr {
45         buf: BytesMut::with_capacity(64),
46         cnt: 0,
47     };
48 
49     let mut buf = Bytes::from_static(b"hello").chain(Bytes::from_static(b"world"));
50 
51     assert_ok!(wr.write_all_buf(&mut buf).await);
52     assert_eq!(wr.buf, b"helloworld"[..]);
53     // expect 4 writes, [hell],[o],[worl],[d]
54     assert_eq!(wr.cnt, 4);
55     assert!(!buf.has_remaining());
56 }
57 
58 #[tokio::test]
write_buf_err()59 async fn write_buf_err() {
60     /// Error out after writing the first 4 bytes
61     struct Wr {
62         cnt: usize,
63     }
64 
65     impl AsyncWrite for Wr {
66         fn poll_write(
67             mut self: Pin<&mut Self>,
68             _cx: &mut Context<'_>,
69             _buf: &[u8],
70         ) -> Poll<io::Result<usize>> {
71             self.cnt += 1;
72             if self.cnt == 2 {
73                 return Poll::Ready(Err(io::Error::new(io::ErrorKind::Other, "whoops")));
74             }
75             Poll::Ready(Ok(4))
76         }
77 
78         fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
79             Ok(()).into()
80         }
81 
82         fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
83             Ok(()).into()
84         }
85     }
86 
87     let mut wr = Wr { cnt: 0 };
88 
89     let mut buf = Bytes::from_static(b"hello").chain(Bytes::from_static(b"world"));
90 
91     assert_err!(wr.write_all_buf(&mut buf).await);
92     assert_eq!(
93         buf.copy_to_bytes(buf.remaining()),
94         Bytes::from_static(b"oworld")
95     );
96 }
97 
98 #[tokio::test]
write_all_buf_vectored()99 async fn write_all_buf_vectored() {
100     struct Wr {
101         buf: BytesMut,
102     }
103     impl AsyncWrite for Wr {
104         fn poll_write(
105             self: Pin<&mut Self>,
106             _cx: &mut Context<'_>,
107             _buf: &[u8],
108         ) -> Poll<io::Result<usize>> {
109             // When executing `write_all_buf` with this writer,
110             // `poll_write` is not called.
111             panic!("shouldn't be called")
112         }
113         fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
114             Ok(()).into()
115         }
116         fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
117             Ok(()).into()
118         }
119         fn poll_write_vectored(
120             mut self: Pin<&mut Self>,
121             _cx: &mut Context<'_>,
122             bufs: &[io::IoSlice<'_>],
123         ) -> Poll<Result<usize, io::Error>> {
124             for buf in bufs {
125                 self.buf.extend_from_slice(buf);
126             }
127             let n = self.buf.len();
128             Ok(n).into()
129         }
130         fn is_write_vectored(&self) -> bool {
131             // Enable vectored write.
132             true
133         }
134     }
135 
136     let mut wr = Wr {
137         buf: BytesMut::with_capacity(64),
138     };
139     let mut buf = Bytes::from_static(b"hello")
140         .chain(Bytes::from_static(b" "))
141         .chain(Bytes::from_static(b"world"));
142 
143     wr.write_all_buf(&mut buf).await.unwrap();
144     assert_eq!(&wr.buf[..], b"hello world");
145 }
146