1 //! HTTP Upgrades
2 //!
3 //! This module deals with managing [HTTP Upgrades][mdn] in hyper. Since
4 //! several concepts in HTTP allow for first talking HTTP, and then converting
5 //! to a different protocol, this module conflates them into a single API.
6 //! Those include:
7 //!
8 //! - HTTP/1.1 Upgrades
9 //! - HTTP `CONNECT`
10 //!
//! You are responsible for any other prerequisites to establish an upgrade,
//! such as sending the appropriate headers, methods, and status codes. You can
//! then use [`on`][] to grab a `Future` which will resolve to the upgraded
//! connection object, or an error if the upgrade fails.
15 //!
16 //! [mdn]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Protocol_upgrade_mechanism
17 //!
18 //! # Client
19 //!
//! Sending an HTTP upgrade from the [`client`](super::client) involves setting
//! either the appropriate method, if you want to `CONNECT`, or headers such as
//! `Upgrade` and `Connection`, on the `http::Request`. Once you receive the
//! `http::Response`, you must check that the server agreed to the upgrade
//! (for example, via a `101` status code), and then get the `Future` from the
//! `Response` with `on()`.
26 //!
27 //! # Server
28 //!
29 //! Receiving upgrade requests in a server requires you to check the relevant
30 //! headers in a `Request`, and if an upgrade should be done, you then send the
31 //! corresponding headers in a response. To then wait for hyper to finish the
32 //! upgrade, you call `on()` with the `Request`, and then can spawn a task
33 //! awaiting it.
34 //!
35 //! # Example
36 //!
37 //! See [this example][example] showing how upgrades work with both
38 //! Clients and Servers.
39 //!
40 //! [example]: https://github.com/hyperium/hyper/blob/master/examples/upgrades.rs
41 
42 use std::any::TypeId;
43 use std::error::Error as StdError;
44 use std::fmt;
45 use std::future::Future;
46 use std::io;
47 use std::marker::Unpin;
48 use std::pin::Pin;
49 use std::task::{Context, Poll};
50 
51 use bytes::Bytes;
52 use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
53 use tokio::sync::oneshot;
54 #[cfg(any(feature = "http1", feature = "http2"))]
55 use tracing::trace;
56 
57 use crate::common::io::Rewind;
58 
/// An upgraded HTTP connection.
///
/// This type holds a trait object internally of the original IO that
/// was used to speak HTTP before the upgrade. It implements `AsyncRead`
/// and `AsyncWrite`, so it can be used directly as an IO object for
/// convenience.
///
/// Alternatively, if the exact type is known, this can be deconstructed
/// into its [`Parts`] with [`downcast`](Upgraded::downcast).
pub struct Upgraded {
    // The type-erased original IO, paired (via `Rewind`) with any bytes that
    // were already read from it but not consumed as HTTP.
    io: Rewind<Box<dyn Io + Send>>,
}
70 
/// A future for a possible HTTP upgrade.
///
/// If no upgrade was available, or it doesn't succeed, yields an `Error`.
///
/// Obtain one with [`on`]. It resolves to the [`Upgraded`] connection once
/// the paired pending upgrade is fulfilled.
pub struct OnUpgrade {
    // Receiving half of the channel created by `pending()`. `None` when the
    // message carried no pending upgrade; polling then yields an error
    // immediately.
    rx: Option<oneshot::Receiver<crate::Result<Upgraded>>>,
}
77 
/// The deconstructed parts of an [`Upgraded`](Upgraded) type.
///
/// Includes the original IO type, and a read buffer of bytes that the
/// HTTP state machine may have already read before completing an upgrade.
///
/// Returned by [`Upgraded::downcast`] when the requested type matches the
/// wrapped IO.
#[derive(Debug)]
pub struct Parts<T> {
    /// The original IO object used before the upgrade.
    pub io: T,
    /// A buffer of bytes that have been read but not processed as HTTP.
    ///
    /// For instance, if the `Connection` is used for an HTTP upgrade request,
    /// it is possible the server sent back the first bytes of the new protocol
    /// along with the response upgrade.
    ///
    /// You will want to check for any existing bytes if you plan to continue
    /// communicating on the IO object.
    pub read_buf: Bytes,
    // Private field: outside this module `Parts` cannot be built with a struct
    // literal or destructured without `..`. NOTE(review): presumably this
    // keeps room for new fields without a breaking change — confirm intent.
    _inner: (),
}
97 
98 /// Gets a pending HTTP upgrade from this message.
99 ///
100 /// This can be called on the following types:
101 ///
102 /// - `http::Request<B>`
103 /// - `http::Response<B>`
104 /// - `&mut http::Request<B>`
105 /// - `&mut http::Response<B>`
on<T: sealed::CanUpgrade>(msg: T) -> OnUpgrade106 pub fn on<T: sealed::CanUpgrade>(msg: T) -> OnUpgrade {
107     msg.on_upgrade()
108 }
109 
/// Sending half of a pending upgrade, created together with its
/// [`OnUpgrade`] by `pending()`.
///
/// Consuming it with `fulfill` (or `manual`, with `http1`) delivers the
/// outcome. Dropping it without doing either makes the paired `OnUpgrade`
/// resolve to a canceled error.
#[cfg(any(feature = "http1", feature = "http2"))]
pub(super) struct Pending {
    tx: oneshot::Sender<crate::Result<Upgraded>>,
}
114 
/// Creates a linked `(Pending, OnUpgrade)` pair.
///
/// Whatever the `Pending` half sends is what the `OnUpgrade` half resolves to.
#[cfg(any(feature = "http1", feature = "http2"))]
pub(super) fn pending() -> (Pending, OnUpgrade) {
    let (sender, receiver) = oneshot::channel();
    let on_upgrade = OnUpgrade { rx: Some(receiver) };
    (Pending { tx: sender }, on_upgrade)
}
120 
121 // ===== impl Upgraded =====
122 
123 impl Upgraded {
124     #[cfg(any(feature = "http1", feature = "http2", test))]
new<T>(io: T, read_buf: Bytes) -> Self where T: AsyncRead + AsyncWrite + Unpin + Send + 'static,125     pub(super) fn new<T>(io: T, read_buf: Bytes) -> Self
126     where
127         T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
128     {
129         Upgraded {
130             io: Rewind::new_buffered(Box::new(io), read_buf),
131         }
132     }
133 
134     /// Tries to downcast the internal trait object to the type passed.
135     ///
136     /// On success, returns the downcasted parts. On error, returns the
137     /// `Upgraded` back.
downcast<T: AsyncRead + AsyncWrite + Unpin + 'static>(self) -> Result<Parts<T>, Self>138     pub fn downcast<T: AsyncRead + AsyncWrite + Unpin + 'static>(self) -> Result<Parts<T>, Self> {
139         let (io, buf) = self.io.into_inner();
140         match io.__hyper_downcast() {
141             Ok(t) => Ok(Parts {
142                 io: *t,
143                 read_buf: buf,
144                 _inner: (),
145             }),
146             Err(io) => Err(Upgraded {
147                 io: Rewind::new_buffered(io, buf),
148             }),
149         }
150     }
151 }
152 
153 impl AsyncRead for Upgraded {
poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<io::Result<()>>154     fn poll_read(
155         mut self: Pin<&mut Self>,
156         cx: &mut Context<'_>,
157         buf: &mut ReadBuf<'_>,
158     ) -> Poll<io::Result<()>> {
159         Pin::new(&mut self.io).poll_read(cx, buf)
160     }
161 }
162 
163 impl AsyncWrite for Upgraded {
poll_write( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll<io::Result<usize>>164     fn poll_write(
165         mut self: Pin<&mut Self>,
166         cx: &mut Context<'_>,
167         buf: &[u8],
168     ) -> Poll<io::Result<usize>> {
169         Pin::new(&mut self.io).poll_write(cx, buf)
170     }
171 
poll_write_vectored( mut self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[io::IoSlice<'_>], ) -> Poll<io::Result<usize>>172     fn poll_write_vectored(
173         mut self: Pin<&mut Self>,
174         cx: &mut Context<'_>,
175         bufs: &[io::IoSlice<'_>],
176     ) -> Poll<io::Result<usize>> {
177         Pin::new(&mut self.io).poll_write_vectored(cx, bufs)
178     }
179 
poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>180     fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
181         Pin::new(&mut self.io).poll_flush(cx)
182     }
183 
poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>184     fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
185         Pin::new(&mut self.io).poll_shutdown(cx)
186     }
187 
is_write_vectored(&self) -> bool188     fn is_write_vectored(&self) -> bool {
189         self.io.is_write_vectored()
190     }
191 }
192 
193 impl fmt::Debug for Upgraded {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result194     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
195         f.debug_struct("Upgraded").finish()
196     }
197 }
198 
199 // ===== impl OnUpgrade =====
200 
201 impl OnUpgrade {
none() -> Self202     pub(super) fn none() -> Self {
203         OnUpgrade { rx: None }
204     }
205 
206     #[cfg(feature = "http1")]
is_none(&self) -> bool207     pub(super) fn is_none(&self) -> bool {
208         self.rx.is_none()
209     }
210 }
211 
212 impl Future for OnUpgrade {
213     type Output = Result<Upgraded, crate::Error>;
214 
poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>215     fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
216         match self.rx {
217             Some(ref mut rx) => Pin::new(rx).poll(cx).map(|res| match res {
218                 Ok(Ok(upgraded)) => Ok(upgraded),
219                 Ok(Err(err)) => Err(err),
220                 Err(_oneshot_canceled) => Err(crate::Error::new_canceled().with(UpgradeExpected)),
221             }),
222             None => Poll::Ready(Err(crate::Error::new_user_no_upgrade())),
223         }
224     }
225 }
226 
227 impl fmt::Debug for OnUpgrade {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result228     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
229         f.debug_struct("OnUpgrade").finish()
230     }
231 }
232 
233 // ===== impl Pending =====
234 
#[cfg(any(feature = "http1", feature = "http2"))]
impl Pending {
    /// Completes the upgrade by sending `upgraded` to the paired `OnUpgrade`.
    ///
    /// If that `OnUpgrade` has already been dropped, the value is discarded.
    pub(super) fn fulfill(self, upgraded: Upgraded) {
        trace!("pending upgrade fulfill");
        let Pending { tx } = self;
        // A send error only means nobody is waiting any more.
        let _ = tx.send(Ok(upgraded));
    }

    /// Don't fulfill the pending Upgrade, but instead signal that
    /// upgrades are handled manually.
    ///
    /// The paired `OnUpgrade` then resolves to the error built by
    /// `crate::Error::new_user_manual_upgrade()`.
    #[cfg(feature = "http1")]
    pub(super) fn manual(self) {
        trace!("pending upgrade handled manually");
        let Pending { tx } = self;
        let _ = tx.send(Err(crate::Error::new_user_manual_upgrade()));
    }
}
250 
251 // ===== impl UpgradeExpected =====
252 
/// Error cause attached when an upgrade was expected, but the sending side
/// was dropped before it delivered anything.
///
/// This likely means the actual `Conn` future wasn't polled and upgraded.
#[derive(Debug)]
struct UpgradeExpected;

impl fmt::Display for UpgradeExpected {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "upgrade expected but not completed")
    }
}

impl StdError for UpgradeExpected {}
267 
268 // ===== impl Io =====
269 
/// IO trait used to type-erase the connection inside `Upgraded`
/// (as `Box<dyn Io + Send>`).
///
/// Every `AsyncRead + AsyncWrite + Unpin + 'static` type implements it
/// through the blanket impl in this block. That impl provides no method
/// bodies, so the default `__hyper_type_id` is the only implementation
/// in use.
pub(super) trait Io: AsyncRead + AsyncWrite + Unpin + 'static {
    /// Returns the `TypeId` of the concrete implementing type.
    ///
    /// Called through a `dyn Io` vtable, this yields the id of the erased
    /// type (not of `dyn Io`). The downcast in `impl dyn Io + Send` relies
    /// on this.
    fn __hyper_type_id(&self) -> TypeId {
        TypeId::of::<Self>()
    }
}

// Blanket impl: any suitable IO type can be boxed as `dyn Io`. It must stay
// empty so `__hyper_type_id` is never overridden.
impl<T: AsyncRead + AsyncWrite + Unpin + 'static> Io for T {}
277 
impl dyn Io + Send {
    /// Returns `true` if the concrete type behind this trait object is `T`.
    ///
    /// `__hyper_type_id` is dispatched through the vtable, so it reports the
    /// erased concrete type's `TypeId`.
    fn __hyper_is<T: Io>(&self) -> bool {
        let t = TypeId::of::<T>();
        self.__hyper_type_id() == t
    }

    /// Attempts to turn this boxed trait object back into a `Box<T>`.
    ///
    /// Returns the original box unchanged in `Err` if the concrete type is
    /// not `T`.
    fn __hyper_downcast<T: Io>(self: Box<Self>) -> Result<Box<T>, Box<Self>> {
        if self.__hyper_is::<T>() {
            // Taken from `std::error::Error::downcast()`.
            // SAFETY: `__hyper_is::<T>()` just confirmed the erased value is
            // exactly a `T`. This relies on `__hyper_type_id` never being
            // overridden: all implementors get it from the empty blanket impl
            // of `Io`. So the data pointer of this fat pointer addresses a
            // valid `T` that this `Box` uniquely owns. Casting `*mut dyn Io`
            // to `*mut T` keeps that data pointer and drops the vtable, and
            // `Box::from_raw` re-takes ownership of the same allocation with
            // `T`'s layout.
            unsafe {
                let raw: *mut dyn Io = Box::into_raw(self);
                Ok(Box::from_raw(raw as *mut T))
            }
        } else {
            Err(self)
        }
    }
}
296 
297 mod sealed {
298     use super::OnUpgrade;
299 
300     pub trait CanUpgrade {
on_upgrade(self) -> OnUpgrade301         fn on_upgrade(self) -> OnUpgrade;
302     }
303 
304     impl<B> CanUpgrade for http::Request<B> {
on_upgrade(mut self) -> OnUpgrade305         fn on_upgrade(mut self) -> OnUpgrade {
306             self.extensions_mut()
307                 .remove::<OnUpgrade>()
308                 .unwrap_or_else(OnUpgrade::none)
309         }
310     }
311 
312     impl<B> CanUpgrade for &'_ mut http::Request<B> {
on_upgrade(self) -> OnUpgrade313         fn on_upgrade(self) -> OnUpgrade {
314             self.extensions_mut()
315                 .remove::<OnUpgrade>()
316                 .unwrap_or_else(OnUpgrade::none)
317         }
318     }
319 
320     impl<B> CanUpgrade for http::Response<B> {
on_upgrade(mut self) -> OnUpgrade321         fn on_upgrade(mut self) -> OnUpgrade {
322             self.extensions_mut()
323                 .remove::<OnUpgrade>()
324                 .unwrap_or_else(OnUpgrade::none)
325         }
326     }
327 
328     impl<B> CanUpgrade for &'_ mut http::Response<B> {
on_upgrade(self) -> OnUpgrade329         fn on_upgrade(self) -> OnUpgrade {
330             self.extensions_mut()
331                 .remove::<OnUpgrade>()
332                 .unwrap_or_else(OnUpgrade::none)
333         }
334     }
335 }
336 
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn upgraded_downcast() {
        let upgraded = Upgraded::new(Mock, Bytes::new());

        let upgraded = upgraded.downcast::<std::io::Cursor<Vec<u8>>>().unwrap_err();

        upgraded.downcast::<Mock>().unwrap();
    }

    #[test]
    fn upgraded_downcast_keeps_read_buf() {
        let upgraded = Upgraded::new(Mock, Bytes::from_static(b"early bytes"));

        // A failed downcast must hand back an `Upgraded` that still owns the
        // buffered bytes...
        let upgraded = upgraded.downcast::<std::io::Cursor<Vec<u8>>>().unwrap_err();

        // ...so that the successful downcast can return them in `Parts`.
        let parts = upgraded.downcast::<Mock>().unwrap();
        assert_eq!(parts.read_buf, Bytes::from_static(b"early bytes"));
    }

    // TODO: replace with tokio_test::io when it can test write_buf
    /// Minimal IO type for the downcast tests. `poll_write` accepts every
    /// byte; all other methods panic if called.
    struct Mock;

    impl AsyncRead for Mock {
        fn poll_read(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            _buf: &mut ReadBuf<'_>,
        ) -> Poll<io::Result<()>> {
            unreachable!("Mock::poll_read")
        }
    }

    impl AsyncWrite for Mock {
        fn poll_write(
            self: Pin<&mut Self>,
            _: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            Poll::Ready(Ok(buf.len()))
        }

        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            unreachable!("Mock::poll_flush")
        }

        fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            unreachable!("Mock::poll_shutdown")
        }
    }
}
382