1 //! Async TLS streams backed by OpenSSL.
2 //!
3 //! This crate provides a wrapper around the [`openssl`] crate's [`SslStream`](ssl::SslStream) type
//! that works with [`tokio`]'s [`AsyncRead`] and [`AsyncWrite`] traits rather than std's
5 //! blocking [`Read`] and [`Write`] traits.
6 #![warn(missing_docs)]
7
8 use openssl::error::ErrorStack;
9 use openssl::ssl::{self, ErrorCode, ShutdownResult, Ssl, SslRef};
10 use std::fmt;
11 use std::future;
12 use std::io::{self, Read, Write};
13 use std::pin::Pin;
14 use std::task::{Context, Poll};
15 use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
16
17 #[cfg(test)]
18 mod test;
19
/// Adapts an async stream to the blocking `Read`/`Write` traits that OpenSSL drives.
///
/// `context` stores the address of the task `Context` for the poll currently in progress,
/// and is `0` whenever no such poll is active.
struct StreamWrapper<S> {
    /// Address of the active task `Context`, or `0` when none is attached.
    context: usize,
    /// The underlying async transport.
    stream: S,
}
24
25 impl<S> fmt::Debug for StreamWrapper<S>
26 where
27 S: fmt::Debug,
28 {
fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result29 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
30 fmt::Debug::fmt(&self.stream, fmt)
31 }
32 }
33
34 impl<S> StreamWrapper<S> {
35 /// # Safety
36 ///
37 /// Must be called with `context` set to a valid pointer to a live `Context` object, and the
38 /// wrapper must be pinned in memory.
parts(&mut self) -> (Pin<&mut S>, &mut Context<'_>)39 unsafe fn parts(&mut self) -> (Pin<&mut S>, &mut Context<'_>) {
40 debug_assert_ne!(self.context, 0);
41 let stream = Pin::new_unchecked(&mut self.stream);
42 let context = &mut *(self.context as *mut _);
43 (stream, context)
44 }
45 }
46
47 impl<S> Read for StreamWrapper<S>
48 where
49 S: AsyncRead,
50 {
read(&mut self, buf: &mut [u8]) -> io::Result<usize>51 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
52 let (stream, cx) = unsafe { self.parts() };
53 let mut buf = ReadBuf::new(buf);
54 match stream.poll_read(cx, &mut buf)? {
55 Poll::Ready(()) => Ok(buf.filled().len()),
56 Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
57 }
58 }
59 }
60
61 impl<S> Write for StreamWrapper<S>
62 where
63 S: AsyncWrite,
64 {
write(&mut self, buf: &[u8]) -> io::Result<usize>65 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
66 let (stream, cx) = unsafe { self.parts() };
67 match stream.poll_write(cx, buf) {
68 Poll::Ready(r) => r,
69 Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
70 }
71 }
72
flush(&mut self) -> io::Result<()>73 fn flush(&mut self) -> io::Result<()> {
74 let (stream, cx) = unsafe { self.parts() };
75 match stream.poll_flush(cx) {
76 Poll::Ready(r) => r,
77 Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
78 }
79 }
80 }
81
/// Maps a blocking-style I/O result back onto the async polling model.
///
/// A `WouldBlock` error means the inner stream was not ready, so it becomes `Poll::Pending`;
/// every other outcome (success or a different error) is returned as `Poll::Ready`.
fn cvt<T>(r: io::Result<T>) -> Poll<io::Result<T>> {
    match r {
        Err(e) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
        finished => Poll::Ready(finished),
    }
}
89
cvt_ossl<T>(r: Result<T, ssl::Error>) -> Poll<Result<T, ssl::Error>>90 fn cvt_ossl<T>(r: Result<T, ssl::Error>) -> Poll<Result<T, ssl::Error>> {
91 match r {
92 Ok(v) => Poll::Ready(Ok(v)),
93 Err(e) => match e.code() {
94 ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => Poll::Pending,
95 _ => Poll::Ready(Err(e)),
96 },
97 }
98 }
99
/// An asynchronous version of [`openssl::ssl::SslStream`].
///
/// The async stream `S` sits inside a synchronous OpenSSL stream. Each `poll_*` method and the
/// [`AsyncRead`]/[`AsyncWrite`] impls temporarily attach the caller's task context. While that
/// context is attached, OpenSSL's blocking reads and writes on `S` become single non-blocking
/// polls, and an unready stream is surfaced to OpenSSL as `WouldBlock`.
#[derive(Debug)]
pub struct SslStream<S>(ssl::SslStream<StreamWrapper<S>>);
103
104 impl<S> SslStream<S>
105 where
106 S: AsyncRead + AsyncWrite,
107 {
108 /// Like [`SslStream::new`](ssl::SslStream::new).
new(ssl: Ssl, stream: S) -> Result<Self, ErrorStack>109 pub fn new(ssl: Ssl, stream: S) -> Result<Self, ErrorStack> {
110 ssl::SslStream::new(ssl, StreamWrapper { stream, context: 0 }).map(SslStream)
111 }
112
113 /// Like [`SslStream::connect`](ssl::SslStream::connect).
poll_connect( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll<Result<(), ssl::Error>>114 pub fn poll_connect(
115 self: Pin<&mut Self>,
116 cx: &mut Context<'_>,
117 ) -> Poll<Result<(), ssl::Error>> {
118 self.with_context(cx, |s| cvt_ossl(s.connect()))
119 }
120
121 /// A convenience method wrapping [`poll_connect`](Self::poll_connect).
connect(mut self: Pin<&mut Self>) -> Result<(), ssl::Error>122 pub async fn connect(mut self: Pin<&mut Self>) -> Result<(), ssl::Error> {
123 future::poll_fn(|cx| self.as_mut().poll_connect(cx)).await
124 }
125
126 /// Like [`SslStream::accept`](ssl::SslStream::accept).
poll_accept(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), ssl::Error>>127 pub fn poll_accept(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), ssl::Error>> {
128 self.with_context(cx, |s| cvt_ossl(s.accept()))
129 }
130
131 /// A convenience method wrapping [`poll_accept`](Self::poll_accept).
accept(mut self: Pin<&mut Self>) -> Result<(), ssl::Error>132 pub async fn accept(mut self: Pin<&mut Self>) -> Result<(), ssl::Error> {
133 future::poll_fn(|cx| self.as_mut().poll_accept(cx)).await
134 }
135
136 /// Like [`SslStream::do_handshake`](ssl::SslStream::do_handshake).
poll_do_handshake( self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll<Result<(), ssl::Error>>137 pub fn poll_do_handshake(
138 self: Pin<&mut Self>,
139 cx: &mut Context<'_>,
140 ) -> Poll<Result<(), ssl::Error>> {
141 self.with_context(cx, |s| cvt_ossl(s.do_handshake()))
142 }
143
144 /// A convenience method wrapping [`poll_do_handshake`](Self::poll_do_handshake).
do_handshake(mut self: Pin<&mut Self>) -> Result<(), ssl::Error>145 pub async fn do_handshake(mut self: Pin<&mut Self>) -> Result<(), ssl::Error> {
146 future::poll_fn(|cx| self.as_mut().poll_do_handshake(cx)).await
147 }
148
149 /// Like [`SslStream::ssl_peek`](ssl::SslStream::ssl_peek).
poll_peek( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8], ) -> Poll<Result<usize, ssl::Error>>150 pub fn poll_peek(
151 self: Pin<&mut Self>,
152 cx: &mut Context<'_>,
153 buf: &mut [u8],
154 ) -> Poll<Result<usize, ssl::Error>> {
155 self.with_context(cx, |s| cvt_ossl(s.ssl_peek(buf)))
156 }
157
158 /// A convenience method wrapping [`poll_peek`](Self::poll_peek).
peek(mut self: Pin<&mut Self>, buf: &mut [u8]) -> Result<usize, ssl::Error>159 pub async fn peek(mut self: Pin<&mut Self>, buf: &mut [u8]) -> Result<usize, ssl::Error> {
160 future::poll_fn(|cx| self.as_mut().poll_peek(cx, buf)).await
161 }
162
163 /// Like [`SslStream::read_early_data`](ssl::SslStream::read_early_data).
164 #[cfg(ossl111)]
poll_read_early_data( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8], ) -> Poll<Result<usize, ssl::Error>>165 pub fn poll_read_early_data(
166 self: Pin<&mut Self>,
167 cx: &mut Context<'_>,
168 buf: &mut [u8],
169 ) -> Poll<Result<usize, ssl::Error>> {
170 self.with_context(cx, |s| cvt_ossl(s.read_early_data(buf)))
171 }
172
173 /// A convenience method wrapping [`poll_read_early_data`](Self::poll_read_early_data).
174 #[cfg(ossl111)]
read_early_data( mut self: Pin<&mut Self>, buf: &mut [u8], ) -> Result<usize, ssl::Error>175 pub async fn read_early_data(
176 mut self: Pin<&mut Self>,
177 buf: &mut [u8],
178 ) -> Result<usize, ssl::Error> {
179 future::poll_fn(|cx| self.as_mut().poll_read_early_data(cx, buf)).await
180 }
181
182 /// Like [`SslStream::write_early_data`](ssl::SslStream::write_early_data).
183 #[cfg(ossl111)]
poll_write_early_data( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll<Result<usize, ssl::Error>>184 pub fn poll_write_early_data(
185 self: Pin<&mut Self>,
186 cx: &mut Context<'_>,
187 buf: &[u8],
188 ) -> Poll<Result<usize, ssl::Error>> {
189 self.with_context(cx, |s| cvt_ossl(s.write_early_data(buf)))
190 }
191
192 /// A convenience method wrapping [`poll_write_early_data`](Self::poll_write_early_data).
193 #[cfg(ossl111)]
write_early_data( mut self: Pin<&mut Self>, buf: &[u8], ) -> Result<usize, ssl::Error>194 pub async fn write_early_data(
195 mut self: Pin<&mut Self>,
196 buf: &[u8],
197 ) -> Result<usize, ssl::Error> {
198 future::poll_fn(|cx| self.as_mut().poll_write_early_data(cx, buf)).await
199 }
200 }
201
202 impl<S> SslStream<S> {
203 /// Returns a shared reference to the `Ssl` object associated with this stream.
ssl(&self) -> &SslRef204 pub fn ssl(&self) -> &SslRef {
205 self.0.ssl()
206 }
207
208 /// Returns a shared reference to the underlying stream.
get_ref(&self) -> &S209 pub fn get_ref(&self) -> &S {
210 &self.0.get_ref().stream
211 }
212
213 /// Returns a mutable reference to the underlying stream.
get_mut(&mut self) -> &mut S214 pub fn get_mut(&mut self) -> &mut S {
215 &mut self.0.get_mut().stream
216 }
217
218 /// Returns a pinned mutable reference to the underlying stream.
get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S>219 pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> {
220 unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().0.get_mut().stream) }
221 }
222
with_context<F, R>(self: Pin<&mut Self>, ctx: &mut Context<'_>, f: F) -> R where F: FnOnce(&mut ssl::SslStream<StreamWrapper<S>>) -> R,223 fn with_context<F, R>(self: Pin<&mut Self>, ctx: &mut Context<'_>, f: F) -> R
224 where
225 F: FnOnce(&mut ssl::SslStream<StreamWrapper<S>>) -> R,
226 {
227 let this = unsafe { self.get_unchecked_mut() };
228 this.0.get_mut().context = ctx as *mut _ as usize;
229 let r = f(&mut this.0);
230 this.0.get_mut().context = 0;
231 r
232 }
233 }
234
235 impl<S> AsyncRead for SslStream<S>
236 where
237 S: AsyncRead + AsyncWrite,
238 {
poll_read( self: Pin<&mut Self>, ctx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<io::Result<()>>239 fn poll_read(
240 self: Pin<&mut Self>,
241 ctx: &mut Context<'_>,
242 buf: &mut ReadBuf<'_>,
243 ) -> Poll<io::Result<()>> {
244 self.with_context(ctx, |s| {
245 // SAFETY: read_uninit does not de-initialize the buffer.
246 match cvt(s.read_uninit(unsafe { buf.unfilled_mut() }))? {
247 Poll::Ready(nread) => {
248 // SAFETY: read_uninit guarantees that nread bytes have been initialized.
249 unsafe { buf.assume_init(nread) };
250 buf.advance(nread);
251 Poll::Ready(Ok(()))
252 }
253 Poll::Pending => Poll::Pending,
254 }
255 })
256 }
257 }
258
259 impl<S> AsyncWrite for SslStream<S>
260 where
261 S: AsyncRead + AsyncWrite,
262 {
poll_write(self: Pin<&mut Self>, ctx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>>263 fn poll_write(self: Pin<&mut Self>, ctx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
264 self.with_context(ctx, |s| cvt(s.write(buf)))
265 }
266
poll_flush(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>>267 fn poll_flush(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
268 self.with_context(ctx, |s| cvt(s.flush()))
269 }
270
poll_shutdown(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>>271 fn poll_shutdown(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
272 match self.as_mut().with_context(ctx, |s| s.shutdown()) {
273 Ok(ShutdownResult::Sent) | Ok(ShutdownResult::Received) => {}
274 Err(ref e) if e.code() == ErrorCode::ZERO_RETURN => {}
275 Err(ref e) if e.code() == ErrorCode::WANT_READ || e.code() == ErrorCode::WANT_WRITE => {
276 return Poll::Pending;
277 }
278 Err(e) => {
279 return Poll::Ready(Err(e
280 .into_io_error()
281 .unwrap_or_else(|e| io::Error::new(io::ErrorKind::Other, e))));
282 }
283 }
284
285 self.get_pin_mut().poll_shutdown(ctx)
286 }
287 }
288