1 use std::future::Future; 2 use std::io; 3 use std::pin::Pin; 4 use std::task::{Context, Poll}; 5 use std::time::Duration; 6 7 use tokio::io::{AsyncRead, AsyncWrite}; 8 use tokio::time::timeout; 9 use tokio_io_timeout::TimeoutStream; 10 11 use hyper::client::connect::{Connected, Connection}; 12 use hyper::{service::Service, Uri}; 13 14 mod stream; 15 16 use stream::TimeoutConnectorStream; 17 18 type BoxError = Box<dyn std::error::Error + Send + Sync>; 19 20 /// A connector that enforces as connection timeout 21 #[derive(Debug, Clone)] 22 pub struct TimeoutConnector<T> { 23 /// A connector implementing the `Connect` trait 24 connector: T, 25 /// Amount of time to wait connecting 26 connect_timeout: Option<Duration>, 27 /// Amount of time to wait reading response 28 read_timeout: Option<Duration>, 29 /// Amount of time to wait writing request 30 write_timeout: Option<Duration>, 31 } 32 33 impl<T> TimeoutConnector<T> 34 where 35 T: Service<Uri> + Send, 36 T::Response: AsyncRead + AsyncWrite + Send + Unpin, 37 T::Future: Send + 'static, 38 T::Error: Into<BoxError>, 39 { 40 /// Construct a new TimeoutConnector with a given connector implementing the `Connect` trait new(connector: T) -> Self41 pub fn new(connector: T) -> Self { 42 TimeoutConnector { 43 connector, 44 connect_timeout: None, 45 read_timeout: None, 46 write_timeout: None, 47 } 48 } 49 } 50 51 impl<T> Service<Uri> for TimeoutConnector<T> 52 where 53 T: Service<Uri> + Send, 54 T::Response: AsyncRead + AsyncWrite + Connection + Send + Unpin, 55 T::Future: Send + 'static, 56 T::Error: Into<BoxError>, 57 { 58 type Response = Pin<Box<TimeoutConnectorStream<T::Response>>>; 59 type Error = BoxError; 60 #[allow(clippy::type_complexity)] 61 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>; 62 poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>63 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { 64 self.connector.poll_ready(cx).map_err(Into::into) 65 } 66 call(&mut self, dst: Uri) -> Self::Future67 fn call(&mut self, dst: Uri) -> Self::Future { 68 let connect_timeout = self.connect_timeout; 69 let read_timeout = self.read_timeout; 70 let write_timeout = self.write_timeout; 71 let connecting = self.connector.call(dst); 72 73 let fut = async move { 74 let stream = match connect_timeout { 75 None => { 76 let io = connecting.await.map_err(Into::into)?; 77 TimeoutStream::new(io) 78 } 79 Some(connect_timeout) => { 80 let timeout = timeout(connect_timeout, connecting); 81 let connecting = timeout 82 .await 83 .map_err(|e| io::Error::new(io::ErrorKind::TimedOut, e))?; 84 let io = connecting.map_err(Into::into)?; 85 TimeoutStream::new(io) 86 } 87 }; 88 89 let mut tm = TimeoutConnectorStream::new(stream); 90 tm.set_read_timeout(read_timeout); 91 tm.set_write_timeout(write_timeout); 92 Ok(Box::pin(tm)) 93 }; 94 95 Box::pin(fut) 96 } 97 } 98 99 impl<T> TimeoutConnector<T> { 100 /// Set the timeout for connecting to a URL. 101 /// 102 /// Default is no timeout. 103 #[inline] set_connect_timeout(&mut self, val: Option<Duration>)104 pub fn set_connect_timeout(&mut self, val: Option<Duration>) { 105 self.connect_timeout = val; 106 } 107 108 /// Set the timeout for the response. 109 /// 110 /// Default is no timeout. 111 #[inline] set_read_timeout(&mut self, val: Option<Duration>)112 pub fn set_read_timeout(&mut self, val: Option<Duration>) { 113 self.read_timeout = val; 114 } 115 116 /// Set the timeout for the request. 117 /// 118 /// Default is no timeout. 119 #[inline] set_write_timeout(&mut self, val: Option<Duration>)120 pub fn set_write_timeout(&mut self, val: Option<Duration>) { 121 self.write_timeout = val; 122 } 123 } 124 125 impl<T> Connection for TimeoutConnector<T> 126 where 127 T: AsyncRead + AsyncWrite + Connection + Service<Uri> + Send + Unpin, 128 T::Response: AsyncRead + AsyncWrite + Send + Unpin, 129 T::Future: Send + 'static, 130 T::Error: Into<BoxError>, 131 { connected(&self) -> Connected132 fn connected(&self) -> Connected { 133 self.connector.connected() 134 } 135 } 136 137 #[cfg(test)] 138 mod tests { 139 use std::error::Error; 140 use std::io; 141 use std::time::Duration; 142 143 use hyper::client::HttpConnector; 144 use hyper::Client; 145 146 use super::TimeoutConnector; 147 148 #[tokio::test] test_timeout_connector()149 async fn test_timeout_connector() { 150 // 10.255.255.1 is a not a routable IP address 151 let url = "http://10.255.255.1".parse().unwrap(); 152 153 let http = HttpConnector::new(); 154 let mut connector = TimeoutConnector::new(http); 155 connector.set_connect_timeout(Some(Duration::from_millis(1))); 156 157 let client = Client::builder().build::<_, hyper::Body>(connector); 158 159 let res = client.get(url).await; 160 161 match res { 162 Ok(_) => panic!("Expected a timeout"), 163 Err(e) => { 164 if let Some(io_e) = e.source().unwrap().downcast_ref::<io::Error>() { 165 assert_eq!(io_e.kind(), io::ErrorKind::TimedOut); 166 } else { 167 panic!("Expected timeout error"); 168 } 169 } 170 } 171 } 172 173 #[tokio::test] test_read_timeout()174 async fn test_read_timeout() { 175 let url = "http://example.com".parse().unwrap(); 176 177 let http = HttpConnector::new(); 178 let mut connector = TimeoutConnector::new(http); 179 // A 1 ms read timeout should be so short that we trigger a timeout error 180 connector.set_read_timeout(Some(Duration::from_millis(1))); 181 182 let client = Client::builder().build::<_, hyper::Body>(connector); 183 184 let res = client.get(url).await; 185 186 match res { 187 Ok(_) => panic!("Expected a timeout"), 188 Err(e) => { 189 if let Some(io_e) = e.source().unwrap().downcast_ref::<io::Error>() { 190 assert_eq!(io_e.kind(), io::ErrorKind::TimedOut); 191 } else { 192 panic!("Expected timeout error"); 193 } 194 } 195 } 196 } 197 } 198