1 use std::future::Future;
2 use std::io;
3 use std::pin::Pin;
4 use std::task::{Context, Poll};
5 use std::time::Duration;
6 
7 use tokio::io::{AsyncRead, AsyncWrite};
8 use tokio::time::timeout;
9 use tokio_io_timeout::TimeoutStream;
10 
11 use hyper::client::connect::{Connected, Connection};
12 use hyper::{service::Service, Uri};
13 
14 mod stream;
15 
16 use stream::TimeoutConnectorStream;
17 
18 type BoxError = Box<dyn std::error::Error + Send + Sync>;
19 
/// A connector wrapper that enforces a connection timeout and carries
/// optional read and write timeouts for the streams it hands out.
#[derive(Clone, Debug)]
pub struct TimeoutConnector<T> {
    /// The wrapped inner connector
    connector: T,
    /// Longest time to spend connecting; `None` means no limit
    connect_timeout: Option<Duration>,
    /// Longest time to spend waiting on a read of the response
    read_timeout: Option<Duration>,
    /// Longest time to spend waiting on a write of the request
    write_timeout: Option<Duration>,
}
32 
33 impl<T> TimeoutConnector<T>
34 where
35     T: Service<Uri> + Send,
36     T::Response: AsyncRead + AsyncWrite + Send + Unpin,
37     T::Future: Send + 'static,
38     T::Error: Into<BoxError>,
39 {
40     /// Construct a new TimeoutConnector with a given connector implementing the `Connect` trait
new(connector: T) -> Self41     pub fn new(connector: T) -> Self {
42         TimeoutConnector {
43             connector,
44             connect_timeout: None,
45             read_timeout: None,
46             write_timeout: None,
47         }
48     }
49 }
50 
51 impl<T> Service<Uri> for TimeoutConnector<T>
52 where
53     T: Service<Uri> + Send,
54     T::Response: AsyncRead + AsyncWrite + Connection + Send + Unpin,
55     T::Future: Send + 'static,
56     T::Error: Into<BoxError>,
57 {
58     type Response = Pin<Box<TimeoutConnectorStream<T::Response>>>;
59     type Error = BoxError;
60     #[allow(clippy::type_complexity)]
61     type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
62 
poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>63     fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
64         self.connector.poll_ready(cx).map_err(Into::into)
65     }
66 
call(&mut self, dst: Uri) -> Self::Future67     fn call(&mut self, dst: Uri) -> Self::Future {
68         let connect_timeout = self.connect_timeout;
69         let read_timeout = self.read_timeout;
70         let write_timeout = self.write_timeout;
71         let connecting = self.connector.call(dst);
72 
73         let fut = async move {
74             let stream = match connect_timeout {
75                 None => {
76                     let io = connecting.await.map_err(Into::into)?;
77                     TimeoutStream::new(io)
78                 }
79                 Some(connect_timeout) => {
80                     let timeout = timeout(connect_timeout, connecting);
81                     let connecting = timeout
82                         .await
83                         .map_err(|e| io::Error::new(io::ErrorKind::TimedOut, e))?;
84                     let io = connecting.map_err(Into::into)?;
85                     TimeoutStream::new(io)
86                 }
87             };
88 
89             let mut tm = TimeoutConnectorStream::new(stream);
90             tm.set_read_timeout(read_timeout);
91             tm.set_write_timeout(write_timeout);
92             Ok(Box::pin(tm))
93         };
94 
95         Box::pin(fut)
96     }
97 }
98 
99 impl<T> TimeoutConnector<T> {
100     /// Set the timeout for connecting to a URL.
101     ///
102     /// Default is no timeout.
103     #[inline]
set_connect_timeout(&mut self, val: Option<Duration>)104     pub fn set_connect_timeout(&mut self, val: Option<Duration>) {
105         self.connect_timeout = val;
106     }
107 
108     /// Set the timeout for the response.
109     ///
110     /// Default is no timeout.
111     #[inline]
set_read_timeout(&mut self, val: Option<Duration>)112     pub fn set_read_timeout(&mut self, val: Option<Duration>) {
113         self.read_timeout = val;
114     }
115 
116     /// Set the timeout for the request.
117     ///
118     /// Default is no timeout.
119     #[inline]
set_write_timeout(&mut self, val: Option<Duration>)120     pub fn set_write_timeout(&mut self, val: Option<Duration>) {
121         self.write_timeout = val;
122     }
123 }
124 
125 impl<T> Connection for TimeoutConnector<T>
126 where
127     T: AsyncRead + AsyncWrite + Connection + Service<Uri> + Send + Unpin,
128     T::Response: AsyncRead + AsyncWrite + Send + Unpin,
129     T::Future: Send + 'static,
130     T::Error: Into<BoxError>,
131 {
connected(&self) -> Connected132     fn connected(&self) -> Connected {
133         self.connector.connected()
134     }
135 }
136 
#[cfg(test)]
mod tests {
    use std::error::Error;
    use std::io;
    use std::time::Duration;

    use hyper::client::HttpConnector;
    use hyper::Client;

    use super::TimeoutConnector;

    /// Panics unless `res` is an error whose source is an `io::Error` of kind
    /// `TimedOut`.
    fn assert_timed_out<T>(res: Result<T, hyper::Error>) {
        let err = match res {
            Ok(_) => panic!("Expected a timeout"),
            Err(err) => err,
        };

        match err.source().unwrap().downcast_ref::<io::Error>() {
            Some(io_err) => assert_eq!(io_err.kind(), io::ErrorKind::TimedOut),
            None => panic!("Expected timeout error"),
        }
    }

    #[tokio::test]
    async fn test_timeout_connector() {
        // 10.255.255.1 is not a routable IP address
        let url = "http://10.255.255.1".parse().unwrap();

        let mut connector = TimeoutConnector::new(HttpConnector::new());
        connector.set_connect_timeout(Some(Duration::from_millis(1)));

        let client = Client::builder().build::<_, hyper::Body>(connector);

        assert_timed_out(client.get(url).await);
    }

    #[tokio::test]
    async fn test_read_timeout() {
        let url = "http://example.com".parse().unwrap();

        let mut connector = TimeoutConnector::new(HttpConnector::new());
        // A 1 ms read timeout is short enough that a timeout error should be
        // triggered
        connector.set_read_timeout(Some(Duration::from_millis(1)));

        let client = Client::builder().build::<_, hyper::Body>(connector);

        assert_timed_out(client.get(url).await);
    }
}
198