1 use std::error::Error as StdError;
2 use std::fmt;
3 use std::future::Future;
4 use std::io;
5 use std::marker::PhantomData;
6 use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
7 use std::pin::Pin;
8 use std::sync::Arc;
9 use std::task::{self, Poll};
10 use std::time::Duration;
11 
12 use futures_util::future::Either;
13 use http::uri::{Scheme, Uri};
14 use pin_project_lite::pin_project;
15 use tokio::net::{TcpSocket, TcpStream};
16 use tokio::time::Sleep;
17 use tracing::{debug, trace, warn};
18 
19 use super::dns::{self, resolve, GaiResolver, Resolve};
20 use super::{Connected, Connection};
21 //#[cfg(feature = "runtime")] use super::dns::TokioThreadpoolGaiResolver;
22 
/// A connector for the `http` scheme.
///
/// Performs DNS resolution in a thread pool, and then connects over TCP.
///
/// # Note
///
/// Sets the [`HttpInfo`](HttpInfo) value on responses, which includes
/// transport information such as the remote socket address used.
#[cfg_attr(docsrs, doc(cfg(feature = "tcp")))]
#[derive(Clone)]
pub struct HttpConnector<R = GaiResolver> {
    // Shared connector settings. Cloning the connector shares this `Arc`;
    // the setters copy it on write (see `config_mut`), so changing one clone
    // never affects another.
    config: Arc<Config>,
    // DNS resolver used for hosts that are not IP literals.
    resolver: R,
}
37 
/// Extra information about the transport when an HttpConnector is used.
///
/// # Example
///
/// ```
/// # async fn doc() -> hyper::Result<()> {
/// use hyper::Uri;
/// use hyper::client::{Client, connect::HttpInfo};
///
/// let client = Client::new();
/// let uri = Uri::from_static("http://example.com");
///
/// let res = client.get(uri).await?;
/// res
///     .extensions()
///     .get::<HttpInfo>()
///     .map(|info| {
///         println!("remote addr = {}", info.remote_addr());
///     });
/// # Ok(())
/// # }
/// ```
///
/// # Note
///
/// If a different connector is used besides [`HttpConnector`](HttpConnector),
/// this value will not exist in the extensions. Consult that specific
/// connector to see what "extra" information it might provide to responses.
#[derive(Clone, Debug)]
pub struct HttpInfo {
    // Address of the peer, as reported by `TcpStream::peer_addr`.
    remote_addr: SocketAddr,
    // Local end of the connection, as reported by `TcpStream::local_addr`.
    local_addr: SocketAddr,
}
71 
// Settings shared by every connection attempt of an `HttpConnector`.
// Defaults are set in `HttpConnector::new_with_resolver`.
#[derive(Clone)]
struct Config {
    // Overall connect timeout; split evenly across the addresses of each
    // attempt group (see `ConnectingTcpRemote::new`).
    connect_timeout: Option<Duration>,
    // When true, only `http` URIs are accepted by `get_host_port`.
    enforce_http: bool,
    // Delay before the fallback address family is also tried; `None`
    // disables the parallel fallback attempt.
    happy_eyeballs_timeout: Option<Duration>,
    // Keepalive idle time applied through `socket2::TcpKeepalive::with_time`.
    keep_alive_timeout: Option<Duration>,
    // Local address to bind before connecting to an IPv4 destination.
    local_address_ipv4: Option<Ipv4Addr>,
    // Local address to bind before connecting to an IPv6 destination.
    local_address_ipv6: Option<Ipv6Addr>,
    // Value passed to `TcpStream::set_nodelay` after connecting.
    nodelay: bool,
    // Whether `SO_REUSEADDR` is enabled on each socket.
    reuse_address: bool,
    // `SO_SNDBUF` size; `None` leaves the OS default.
    send_buffer_size: Option<usize>,
    // `SO_RCVBUF` size; `None` leaves the OS default.
    recv_buffer_size: Option<usize>,
}
85 
86 // ===== impl HttpConnector =====
87 
88 impl HttpConnector {
89     /// Construct a new HttpConnector.
new() -> HttpConnector90     pub fn new() -> HttpConnector {
91         HttpConnector::new_with_resolver(GaiResolver::new())
92     }
93 }
94 
95 /*
96 #[cfg(feature = "runtime")]
97 impl HttpConnector<TokioThreadpoolGaiResolver> {
98     /// Construct a new HttpConnector using the `TokioThreadpoolGaiResolver`.
99     ///
100     /// This resolver **requires** the threadpool runtime to be used.
101     pub fn new_with_tokio_threadpool_resolver() -> Self {
102         HttpConnector::new_with_resolver(TokioThreadpoolGaiResolver::new())
103     }
104 }
105 */
106 
107 impl<R> HttpConnector<R> {
108     /// Construct a new HttpConnector.
109     ///
110     /// Takes a [`Resolver`](crate::client::connect::dns#resolvers-are-services) to handle DNS lookups.
new_with_resolver(resolver: R) -> HttpConnector<R>111     pub fn new_with_resolver(resolver: R) -> HttpConnector<R> {
112         HttpConnector {
113             config: Arc::new(Config {
114                 connect_timeout: None,
115                 enforce_http: true,
116                 happy_eyeballs_timeout: Some(Duration::from_millis(300)),
117                 keep_alive_timeout: None,
118                 local_address_ipv4: None,
119                 local_address_ipv6: None,
120                 nodelay: false,
121                 reuse_address: false,
122                 send_buffer_size: None,
123                 recv_buffer_size: None,
124             }),
125             resolver,
126         }
127     }
128 
129     /// Option to enforce all `Uri`s have the `http` scheme.
130     ///
131     /// Enabled by default.
132     #[inline]
enforce_http(&mut self, is_enforced: bool)133     pub fn enforce_http(&mut self, is_enforced: bool) {
134         self.config_mut().enforce_http = is_enforced;
135     }
136 
137     /// Set that all sockets have `SO_KEEPALIVE` set with the supplied duration.
138     ///
139     /// If `None`, the option will not be set.
140     ///
141     /// Default is `None`.
142     #[inline]
set_keepalive(&mut self, dur: Option<Duration>)143     pub fn set_keepalive(&mut self, dur: Option<Duration>) {
144         self.config_mut().keep_alive_timeout = dur;
145     }
146 
147     /// Set that all sockets have `SO_NODELAY` set to the supplied value `nodelay`.
148     ///
149     /// Default is `false`.
150     #[inline]
set_nodelay(&mut self, nodelay: bool)151     pub fn set_nodelay(&mut self, nodelay: bool) {
152         self.config_mut().nodelay = nodelay;
153     }
154 
155     /// Sets the value of the SO_SNDBUF option on the socket.
156     #[inline]
set_send_buffer_size(&mut self, size: Option<usize>)157     pub fn set_send_buffer_size(&mut self, size: Option<usize>) {
158         self.config_mut().send_buffer_size = size;
159     }
160 
161     /// Sets the value of the SO_RCVBUF option on the socket.
162     #[inline]
set_recv_buffer_size(&mut self, size: Option<usize>)163     pub fn set_recv_buffer_size(&mut self, size: Option<usize>) {
164         self.config_mut().recv_buffer_size = size;
165     }
166 
167     /// Set that all sockets are bound to the configured address before connection.
168     ///
169     /// If `None`, the sockets will not be bound.
170     ///
171     /// Default is `None`.
172     #[inline]
set_local_address(&mut self, addr: Option<IpAddr>)173     pub fn set_local_address(&mut self, addr: Option<IpAddr>) {
174         let (v4, v6) = match addr {
175             Some(IpAddr::V4(a)) => (Some(a), None),
176             Some(IpAddr::V6(a)) => (None, Some(a)),
177             _ => (None, None),
178         };
179 
180         let cfg = self.config_mut();
181 
182         cfg.local_address_ipv4 = v4;
183         cfg.local_address_ipv6 = v6;
184     }
185 
186     /// Set that all sockets are bound to the configured IPv4 or IPv6 address (depending on host's
187     /// preferences) before connection.
188     #[inline]
set_local_addresses(&mut self, addr_ipv4: Ipv4Addr, addr_ipv6: Ipv6Addr)189     pub fn set_local_addresses(&mut self, addr_ipv4: Ipv4Addr, addr_ipv6: Ipv6Addr) {
190         let cfg = self.config_mut();
191 
192         cfg.local_address_ipv4 = Some(addr_ipv4);
193         cfg.local_address_ipv6 = Some(addr_ipv6);
194     }
195 
196     /// Set the connect timeout.
197     ///
198     /// If a domain resolves to multiple IP addresses, the timeout will be
199     /// evenly divided across them.
200     ///
201     /// Default is `None`.
202     #[inline]
set_connect_timeout(&mut self, dur: Option<Duration>)203     pub fn set_connect_timeout(&mut self, dur: Option<Duration>) {
204         self.config_mut().connect_timeout = dur;
205     }
206 
207     /// Set timeout for [RFC 6555 (Happy Eyeballs)][RFC 6555] algorithm.
208     ///
209     /// If hostname resolves to both IPv4 and IPv6 addresses and connection
210     /// cannot be established using preferred address family before timeout
211     /// elapses, then connector will in parallel attempt connection using other
212     /// address family.
213     ///
214     /// If `None`, parallel connection attempts are disabled.
215     ///
216     /// Default is 300 milliseconds.
217     ///
218     /// [RFC 6555]: https://tools.ietf.org/html/rfc6555
219     #[inline]
set_happy_eyeballs_timeout(&mut self, dur: Option<Duration>)220     pub fn set_happy_eyeballs_timeout(&mut self, dur: Option<Duration>) {
221         self.config_mut().happy_eyeballs_timeout = dur;
222     }
223 
224     /// Set that all socket have `SO_REUSEADDR` set to the supplied value `reuse_address`.
225     ///
226     /// Default is `false`.
227     #[inline]
set_reuse_address(&mut self, reuse_address: bool) -> &mut Self228     pub fn set_reuse_address(&mut self, reuse_address: bool) -> &mut Self {
229         self.config_mut().reuse_address = reuse_address;
230         self
231     }
232 
233     // private
234 
config_mut(&mut self) -> &mut Config235     fn config_mut(&mut self) -> &mut Config {
236         // If the are HttpConnector clones, this will clone the inner
237         // config. So mutating the config won't ever affect previous
238         // clones.
239         Arc::make_mut(&mut self.config)
240     }
241 }
242 
// Messages for the `ConnectError`s that `get_host_port` returns when it
// rejects a destination `Uri`.
static INVALID_NOT_HTTP: &str = "invalid URL, scheme is not http";
static INVALID_MISSING_SCHEME: &str = "invalid URL, scheme is missing";
static INVALID_MISSING_HOST: &str = "invalid URL, host is missing";
246 
247 // R: Debug required for now to allow adding it to debug output later...
248 impl<R: fmt::Debug> fmt::Debug for HttpConnector<R> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result249     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
250         f.debug_struct("HttpConnector").finish()
251     }
252 }
253 
254 impl<R> tower_service::Service<Uri> for HttpConnector<R>
255 where
256     R: Resolve + Clone + Send + Sync + 'static,
257     R::Future: Send,
258 {
259     type Response = TcpStream;
260     type Error = ConnectError;
261     type Future = HttpConnecting<R>;
262 
poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>>263     fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
264         ready!(self.resolver.poll_ready(cx)).map_err(ConnectError::dns)?;
265         Poll::Ready(Ok(()))
266     }
267 
call(&mut self, dst: Uri) -> Self::Future268     fn call(&mut self, dst: Uri) -> Self::Future {
269         let mut self_ = self.clone();
270         HttpConnecting {
271             fut: Box::pin(async move { self_.call_async(dst).await }),
272             _marker: PhantomData,
273         }
274     }
275 }
276 
get_host_port<'u>(config: &Config, dst: &'u Uri) -> Result<(&'u str, u16), ConnectError>277 fn get_host_port<'u>(config: &Config, dst: &'u Uri) -> Result<(&'u str, u16), ConnectError> {
278     trace!(
279         "Http::connect; scheme={:?}, host={:?}, port={:?}",
280         dst.scheme(),
281         dst.host(),
282         dst.port(),
283     );
284 
285     if config.enforce_http {
286         if dst.scheme() != Some(&Scheme::HTTP) {
287             return Err(ConnectError {
288                 msg: INVALID_NOT_HTTP.into(),
289                 cause: None,
290             });
291         }
292     } else if dst.scheme().is_none() {
293         return Err(ConnectError {
294             msg: INVALID_MISSING_SCHEME.into(),
295             cause: None,
296         });
297     }
298 
299     let host = match dst.host() {
300         Some(s) => s,
301         None => {
302             return Err(ConnectError {
303                 msg: INVALID_MISSING_HOST.into(),
304                 cause: None,
305             })
306         }
307     };
308     let port = match dst.port() {
309         Some(port) => port.as_u16(),
310         None => {
311             if dst.scheme() == Some(&Scheme::HTTPS) {
312                 443
313             } else {
314                 80
315             }
316         }
317     };
318 
319     Ok((host, port))
320 }
321 
322 impl<R> HttpConnector<R>
323 where
324     R: Resolve,
325 {
call_async(&mut self, dst: Uri) -> Result<TcpStream, ConnectError>326     async fn call_async(&mut self, dst: Uri) -> Result<TcpStream, ConnectError> {
327         let config = &self.config;
328 
329         let (host, port) = get_host_port(config, &dst)?;
330         let host = host.trim_start_matches('[').trim_end_matches(']');
331 
332         // If the host is already an IP addr (v4 or v6),
333         // skip resolving the dns and start connecting right away.
334         let addrs = if let Some(addrs) = dns::SocketAddrs::try_parse(host, port) {
335             addrs
336         } else {
337             let addrs = resolve(&mut self.resolver, dns::Name::new(host.into()))
338                 .await
339                 .map_err(ConnectError::dns)?;
340             let addrs = addrs
341                 .map(|mut addr| {
342                     addr.set_port(port);
343                     addr
344                 })
345                 .collect();
346             dns::SocketAddrs::new(addrs)
347         };
348 
349         let c = ConnectingTcp::new(addrs, config);
350 
351         let sock = c.connect().await?;
352 
353         if let Err(e) = sock.set_nodelay(config.nodelay) {
354             warn!("tcp set_nodelay error: {}", e);
355         }
356 
357         Ok(sock)
358     }
359 }
360 
361 impl Connection for TcpStream {
connected(&self) -> Connected362     fn connected(&self) -> Connected {
363         let connected = Connected::new();
364         if let (Ok(remote_addr), Ok(local_addr)) = (self.peer_addr(), self.local_addr()) {
365             connected.extra(HttpInfo {
366                 remote_addr,
367                 local_addr,
368             })
369         } else {
370             connected
371         }
372     }
373 }
374 
375 impl HttpInfo {
376     /// Get the remote address of the transport used.
remote_addr(&self) -> SocketAddr377     pub fn remote_addr(&self) -> SocketAddr {
378         self.remote_addr
379     }
380 
381     /// Get the local address of the transport used.
local_addr(&self) -> SocketAddr382     pub fn local_addr(&self) -> SocketAddr {
383         self.local_addr
384     }
385 }
386 
pin_project! {
    // Not publicly exported (so missing_docs doesn't trigger).
    //
    // We return this `Future` instead of the `Pin<Box<dyn Future>>` directly
    // so that users don't rely on it fitting in a `Pin<Box<dyn Future>>` slot
    // (and thus we can change the type in the future).
    #[must_use = "futures do nothing unless polled"]
    #[allow(missing_debug_implementations)]
    pub struct HttpConnecting<R> {
        // The boxed connect future built in `HttpConnector::call`.
        #[pin]
        fut: BoxConnecting,
        // Ties the future to the resolver type `R` without storing one.
        _marker: PhantomData<R>,
    }
}
401 
// Outcome of a single `HttpConnector` connection attempt.
type ConnectResult = Result<TcpStream, ConnectError>;
// Type-erased, `Send` connect future stored inside `HttpConnecting`.
type BoxConnecting = Pin<Box<dyn Future<Output = ConnectResult> + Send>>;
404 
405 impl<R: Resolve> Future for HttpConnecting<R> {
406     type Output = ConnectResult;
407 
poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output>408     fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
409         self.project().fut.poll(cx)
410     }
411 }
412 
// Not publicly exported (so missing_docs doesn't trigger).
//
// Error returned by `HttpConnector`: a short description plus, optionally,
// the underlying error.
pub struct ConnectError {
    // Short description, e.g. "tcp connect error" or one of the INVALID_* messages.
    msg: Box<str>,
    // Underlying error, if any; exposed through `Error::source`.
    cause: Option<Box<dyn StdError + Send + Sync>>,
}
418 
419 impl ConnectError {
new<S, E>(msg: S, cause: E) -> ConnectError where S: Into<Box<str>>, E: Into<Box<dyn StdError + Send + Sync>>,420     fn new<S, E>(msg: S, cause: E) -> ConnectError
421     where
422         S: Into<Box<str>>,
423         E: Into<Box<dyn StdError + Send + Sync>>,
424     {
425         ConnectError {
426             msg: msg.into(),
427             cause: Some(cause.into()),
428         }
429     }
430 
dns<E>(cause: E) -> ConnectError where E: Into<Box<dyn StdError + Send + Sync>>,431     fn dns<E>(cause: E) -> ConnectError
432     where
433         E: Into<Box<dyn StdError + Send + Sync>>,
434     {
435         ConnectError::new("dns error", cause)
436     }
437 
m<S, E>(msg: S) -> impl FnOnce(E) -> ConnectError where S: Into<Box<str>>, E: Into<Box<dyn StdError + Send + Sync>>,438     fn m<S, E>(msg: S) -> impl FnOnce(E) -> ConnectError
439     where
440         S: Into<Box<str>>,
441         E: Into<Box<dyn StdError + Send + Sync>>,
442     {
443         move |cause| ConnectError::new(msg, cause)
444     }
445 }
446 
447 impl fmt::Debug for ConnectError {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result448     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
449         if let Some(ref cause) = self.cause {
450             f.debug_tuple("ConnectError")
451                 .field(&self.msg)
452                 .field(cause)
453                 .finish()
454         } else {
455             self.msg.fmt(f)
456         }
457     }
458 }
459 
460 impl fmt::Display for ConnectError {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result461     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
462         f.write_str(&self.msg)?;
463 
464         if let Some(ref cause) = self.cause {
465             write!(f, ": {}", cause)?;
466         }
467 
468         Ok(())
469     }
470 }
471 
472 impl StdError for ConnectError {
source(&self) -> Option<&(dyn StdError + 'static)>473     fn source(&self) -> Option<&(dyn StdError + 'static)> {
474         self.cause.as_ref().map(|e| &**e as _)
475     }
476 }
477 
// Connection plan for one destination: the preferred address group, an
// optional delayed fallback group (Happy Eyeballs), and the connector config
// used to set up every socket.
struct ConnectingTcp<'a> {
    // Addresses tried first.
    preferred: ConnectingTcpRemote,
    // Addresses of the other family plus the delay before they are tried;
    // `None` when Happy Eyeballs is disabled or there is nothing to fall back to.
    fallback: Option<ConnectingTcpFallback>,
    config: &'a Config,
}
483 
484 impl<'a> ConnectingTcp<'a> {
new(remote_addrs: dns::SocketAddrs, config: &'a Config) -> Self485     fn new(remote_addrs: dns::SocketAddrs, config: &'a Config) -> Self {
486         if let Some(fallback_timeout) = config.happy_eyeballs_timeout {
487             let (preferred_addrs, fallback_addrs) = remote_addrs
488                 .split_by_preference(config.local_address_ipv4, config.local_address_ipv6);
489             if fallback_addrs.is_empty() {
490                 return ConnectingTcp {
491                     preferred: ConnectingTcpRemote::new(preferred_addrs, config.connect_timeout),
492                     fallback: None,
493                     config,
494                 };
495             }
496 
497             ConnectingTcp {
498                 preferred: ConnectingTcpRemote::new(preferred_addrs, config.connect_timeout),
499                 fallback: Some(ConnectingTcpFallback {
500                     delay: tokio::time::sleep(fallback_timeout),
501                     remote: ConnectingTcpRemote::new(fallback_addrs, config.connect_timeout),
502                 }),
503                 config,
504             }
505         } else {
506             ConnectingTcp {
507                 preferred: ConnectingTcpRemote::new(remote_addrs, config.connect_timeout),
508                 fallback: None,
509                 config,
510             }
511         }
512     }
513 }
514 
// Fallback half of a Happy Eyeballs plan.
struct ConnectingTcpFallback {
    // Timer created from `happy_eyeballs_timeout`. The fallback attempt
    // starts once it fires, or earlier if the preferred attempt fails.
    delay: Sleep,
    remote: ConnectingTcpRemote,
}

// A group of addresses tried one after another.
struct ConnectingTcpRemote {
    addrs: dns::SocketAddrs,
    // Timeout for each individual address: the overall connect timeout
    // divided by the number of addresses in this group.
    connect_timeout: Option<Duration>,
}
524 
525 impl ConnectingTcpRemote {
new(addrs: dns::SocketAddrs, connect_timeout: Option<Duration>) -> Self526     fn new(addrs: dns::SocketAddrs, connect_timeout: Option<Duration>) -> Self {
527         let connect_timeout = connect_timeout
528             .map(|t| t.checked_div(addrs.len() as u32))
529             .flatten();
530 
531         Self {
532             addrs,
533             connect_timeout,
534         }
535     }
536 }
537 
538 impl ConnectingTcpRemote {
connect(&mut self, config: &Config) -> Result<TcpStream, ConnectError>539     async fn connect(&mut self, config: &Config) -> Result<TcpStream, ConnectError> {
540         let mut err = None;
541         for addr in &mut self.addrs {
542             debug!("connecting to {}", addr);
543             match connect(&addr, config, self.connect_timeout)?.await {
544                 Ok(tcp) => {
545                     debug!("connected to {}", addr);
546                     return Ok(tcp);
547                 }
548                 Err(e) => {
549                     trace!("connect error for {}: {:?}", addr, e);
550                     err = Some(e);
551                 }
552             }
553         }
554 
555         match err {
556             Some(e) => Err(e),
557             None => Err(ConnectError::new(
558                 "tcp connect error",
559                 std::io::Error::new(std::io::ErrorKind::NotConnected, "Network unreachable"),
560             )),
561         }
562     }
563 }
564 
bind_local_address( socket: &socket2::Socket, dst_addr: &SocketAddr, local_addr_ipv4: &Option<Ipv4Addr>, local_addr_ipv6: &Option<Ipv6Addr>, ) -> io::Result<()>565 fn bind_local_address(
566     socket: &socket2::Socket,
567     dst_addr: &SocketAddr,
568     local_addr_ipv4: &Option<Ipv4Addr>,
569     local_addr_ipv6: &Option<Ipv6Addr>,
570 ) -> io::Result<()> {
571     match (*dst_addr, local_addr_ipv4, local_addr_ipv6) {
572         (SocketAddr::V4(_), Some(addr), _) => {
573             socket.bind(&SocketAddr::new(addr.clone().into(), 0).into())?;
574         }
575         (SocketAddr::V6(_), _, Some(addr)) => {
576             socket.bind(&SocketAddr::new(addr.clone().into(), 0).into())?;
577         }
578         _ => {
579             if cfg!(windows) {
580                 // Windows requires a socket be bound before calling connect
581                 let any: SocketAddr = match *dst_addr {
582                     SocketAddr::V4(_) => ([0, 0, 0, 0], 0).into(),
583                     SocketAddr::V6(_) => ([0, 0, 0, 0, 0, 0, 0, 0], 0).into(),
584                 };
585                 socket.bind(&any.into())?;
586             }
587         }
588     }
589 
590     Ok(())
591 }
592 
connect( addr: &SocketAddr, config: &Config, connect_timeout: Option<Duration>, ) -> Result<impl Future<Output = Result<TcpStream, ConnectError>>, ConnectError>593 fn connect(
594     addr: &SocketAddr,
595     config: &Config,
596     connect_timeout: Option<Duration>,
597 ) -> Result<impl Future<Output = Result<TcpStream, ConnectError>>, ConnectError> {
598     // TODO(eliza): if Tokio's `TcpSocket` gains support for setting the
599     // keepalive timeout, it would be nice to use that instead of socket2,
600     // and avoid the unsafe `into_raw_fd`/`from_raw_fd` dance...
601     use socket2::{Domain, Protocol, Socket, TcpKeepalive, Type};
602     use std::convert::TryInto;
603 
604     let domain = Domain::for_address(*addr);
605     let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))
606         .map_err(ConnectError::m("tcp open error"))?;
607 
608     // When constructing a Tokio `TcpSocket` from a raw fd/socket, the user is
609     // responsible for ensuring O_NONBLOCK is set.
610     socket
611         .set_nonblocking(true)
612         .map_err(ConnectError::m("tcp set_nonblocking error"))?;
613 
614     if let Some(dur) = config.keep_alive_timeout {
615         let conf = TcpKeepalive::new().with_time(dur);
616         if let Err(e) = socket.set_tcp_keepalive(&conf) {
617             warn!("tcp set_keepalive error: {}", e);
618         }
619     }
620 
621     bind_local_address(
622         &socket,
623         addr,
624         &config.local_address_ipv4,
625         &config.local_address_ipv6,
626     )
627     .map_err(ConnectError::m("tcp bind local error"))?;
628 
629     #[cfg(unix)]
630     let socket = unsafe {
631         // Safety: `from_raw_fd` is only safe to call if ownership of the raw
632         // file descriptor is transferred. Since we call `into_raw_fd` on the
633         // socket2 socket, it gives up ownership of the fd and will not close
634         // it, so this is safe.
635         use std::os::unix::io::{FromRawFd, IntoRawFd};
636         TcpSocket::from_raw_fd(socket.into_raw_fd())
637     };
638     #[cfg(windows)]
639     let socket = unsafe {
640         // Safety: `from_raw_socket` is only safe to call if ownership of the raw
641         // Windows SOCKET is transferred. Since we call `into_raw_socket` on the
642         // socket2 socket, it gives up ownership of the SOCKET and will not close
643         // it, so this is safe.
644         use std::os::windows::io::{FromRawSocket, IntoRawSocket};
645         TcpSocket::from_raw_socket(socket.into_raw_socket())
646     };
647 
648     if config.reuse_address {
649         if let Err(e) = socket.set_reuseaddr(true) {
650             warn!("tcp set_reuse_address error: {}", e);
651         }
652     }
653 
654     if let Some(size) = config.send_buffer_size {
655         if let Err(e) = socket.set_send_buffer_size(size.try_into().unwrap_or(std::u32::MAX)) {
656             warn!("tcp set_buffer_size error: {}", e);
657         }
658     }
659 
660     if let Some(size) = config.recv_buffer_size {
661         if let Err(e) = socket.set_recv_buffer_size(size.try_into().unwrap_or(std::u32::MAX)) {
662             warn!("tcp set_recv_buffer_size error: {}", e);
663         }
664     }
665 
666     let connect = socket.connect(*addr);
667     Ok(async move {
668         match connect_timeout {
669             Some(dur) => match tokio::time::timeout(dur, connect).await {
670                 Ok(Ok(s)) => Ok(s),
671                 Ok(Err(e)) => Err(e),
672                 Err(e) => Err(io::Error::new(io::ErrorKind::TimedOut, e)),
673             },
674             None => connect.await,
675         }
676         .map_err(ConnectError::m("tcp connect error"))
677     })
678 }
679 
impl ConnectingTcp<'_> {
    /// Runs the connection plan and returns the first successfully connected stream.
    ///
    /// Without a fallback group, this simply tries the preferred addresses in
    /// order. With a fallback group, the preferred attempt runs first. The
    /// fallback attempt is polled once the fallback delay elapses (racing the
    /// preferred attempt from then on), or immediately if the preferred
    /// attempt fails before the delay. This is the RFC 6555 "Happy Eyeballs"
    /// behaviour configured via `set_happy_eyeballs_timeout`.
    async fn connect(mut self) -> Result<TcpStream, ConnectError> {
        match self.fallback {
            None => self.preferred.connect(self.config).await,
            Some(mut fallback) => {
                // Both attempts are created up front, but futures are lazy:
                // `fallback_fut` does no work until it is polled below.
                let preferred_fut = self.preferred.connect(self.config);
                futures_util::pin_mut!(preferred_fut);

                let fallback_fut = fallback.remote.connect(self.config);
                futures_util::pin_mut!(fallback_fut);

                let fallback_delay = fallback.delay;
                futures_util::pin_mut!(fallback_delay);

                // Race the preferred attempt against the fallback delay.
                // `future` is whichever attempt has not produced `result`.
                let (result, future) =
                    match futures_util::future::select(preferred_fut, fallback_delay).await {
                        Either::Left((result, _fallback_delay)) => {
                            // The preferred attempt finished (Ok or Err) before
                            // the delay; the fallback has not been polled yet.
                            (result, Either::Right(fallback_fut))
                        }
                        Either::Right(((), preferred_fut)) => {
                            // Delay is done, start polling both the preferred and the fallback
                            futures_util::future::select(preferred_fut, fallback_fut)
                                .await
                                .factor_first()
                        }
                    };

                if result.is_err() {
                    // Fallback to the remaining future (could be preferred or fallback)
                    // if we get an error
                    future.await
                } else {
                    result
                }
            }
        }
    }
}
718 
719 #[cfg(test)]
720 mod tests {
721     use std::io;
722 
723     use ::http::Uri;
724 
725     use super::super::sealed::{Connect, ConnectSvc};
726     use super::{Config, ConnectError, HttpConnector};
727 
connect<C>( connector: C, dst: Uri, ) -> Result<<C::_Svc as ConnectSvc>::Connection, <C::_Svc as ConnectSvc>::Error> where C: Connect,728     async fn connect<C>(
729         connector: C,
730         dst: Uri,
731     ) -> Result<<C::_Svc as ConnectSvc>::Connection, <C::_Svc as ConnectSvc>::Error>
732     where
733         C: Connect,
734     {
735         connector.connect(super::super::sealed::Internal, dst).await
736     }
737 
738     #[tokio::test]
test_errors_enforce_http()739     async fn test_errors_enforce_http() {
740         let dst = "https://example.domain/foo/bar?baz".parse().unwrap();
741         let connector = HttpConnector::new();
742 
743         let err = connect(connector, dst).await.unwrap_err();
744         assert_eq!(&*err.msg, super::INVALID_NOT_HTTP);
745     }
746 
747     #[cfg(any(target_os = "linux", target_os = "macos"))]
get_local_ips() -> (Option<std::net::Ipv4Addr>, Option<std::net::Ipv6Addr>)748     fn get_local_ips() -> (Option<std::net::Ipv4Addr>, Option<std::net::Ipv6Addr>) {
749         use std::net::{IpAddr, TcpListener};
750 
751         let mut ip_v4 = None;
752         let mut ip_v6 = None;
753 
754         let ips = pnet_datalink::interfaces()
755             .into_iter()
756             .flat_map(|i| i.ips.into_iter().map(|n| n.ip()));
757 
758         for ip in ips {
759             match ip {
760                 IpAddr::V4(ip) if TcpListener::bind((ip, 0)).is_ok() => ip_v4 = Some(ip),
761                 IpAddr::V6(ip) if TcpListener::bind((ip, 0)).is_ok() => ip_v6 = Some(ip),
762                 _ => (),
763             }
764 
765             if ip_v4.is_some() && ip_v6.is_some() {
766                 break;
767             }
768         }
769 
770         (ip_v4, ip_v6)
771     }
772 
773     #[tokio::test]
test_errors_missing_scheme()774     async fn test_errors_missing_scheme() {
775         let dst = "example.domain".parse().unwrap();
776         let mut connector = HttpConnector::new();
777         connector.enforce_http(false);
778 
779         let err = connect(connector, dst).await.unwrap_err();
780         assert_eq!(&*err.msg, super::INVALID_MISSING_SCHEME);
781     }
782 
783     // NOTE: pnet crate that we use in this test doesn't compile on Windows
784     #[cfg(any(target_os = "linux", target_os = "macos"))]
785     #[tokio::test]
local_address()786     async fn local_address() {
787         use std::net::{IpAddr, TcpListener};
788         let _ = pretty_env_logger::try_init();
789 
790         let (bind_ip_v4, bind_ip_v6) = get_local_ips();
791         let server4 = TcpListener::bind("127.0.0.1:0").unwrap();
792         let port = server4.local_addr().unwrap().port();
793         let server6 = TcpListener::bind(&format!("[::1]:{}", port)).unwrap();
794 
795         let assert_client_ip = |dst: String, server: TcpListener, expected_ip: IpAddr| async move {
796             let mut connector = HttpConnector::new();
797 
798             match (bind_ip_v4, bind_ip_v6) {
799                 (Some(v4), Some(v6)) => connector.set_local_addresses(v4, v6),
800                 (Some(v4), None) => connector.set_local_address(Some(v4.into())),
801                 (None, Some(v6)) => connector.set_local_address(Some(v6.into())),
802                 _ => unreachable!(),
803             }
804 
805             connect(connector, dst.parse().unwrap()).await.unwrap();
806 
807             let (_, client_addr) = server.accept().unwrap();
808 
809             assert_eq!(client_addr.ip(), expected_ip);
810         };
811 
812         if let Some(ip) = bind_ip_v4 {
813             assert_client_ip(format!("http://127.0.0.1:{}", port), server4, ip.into()).await;
814         }
815 
816         if let Some(ip) = bind_ip_v6 {
817             assert_client_ip(format!("http://[::1]:{}", port), server6, ip.into()).await;
818         }
819     }
820 
821     #[test]
822     #[cfg_attr(not(feature = "__internal_happy_eyeballs_tests"), ignore)]
client_happy_eyeballs()823     fn client_happy_eyeballs() {
824         use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, TcpListener};
825         use std::time::{Duration, Instant};
826 
827         use super::dns;
828         use super::ConnectingTcp;
829 
830         let _ = pretty_env_logger::try_init();
831         let server4 = TcpListener::bind("127.0.0.1:0").unwrap();
832         let addr = server4.local_addr().unwrap();
833         let _server6 = TcpListener::bind(&format!("[::1]:{}", addr.port())).unwrap();
834         let rt = tokio::runtime::Builder::new_current_thread()
835             .enable_all()
836             .build()
837             .unwrap();
838 
839         let local_timeout = Duration::default();
840         let unreachable_v4_timeout = measure_connect(unreachable_ipv4_addr()).1;
841         let unreachable_v6_timeout = measure_connect(unreachable_ipv6_addr()).1;
842         let fallback_timeout = std::cmp::max(unreachable_v4_timeout, unreachable_v6_timeout)
843             + Duration::from_millis(250);
844 
845         let scenarios = &[
846             // Fast primary, without fallback.
847             (&[local_ipv4_addr()][..], 4, local_timeout, false),
848             (&[local_ipv6_addr()][..], 6, local_timeout, false),
849             // Fast primary, with (unused) fallback.
850             (
851                 &[local_ipv4_addr(), local_ipv6_addr()][..],
852                 4,
853                 local_timeout,
854                 false,
855             ),
856             (
857                 &[local_ipv6_addr(), local_ipv4_addr()][..],
858                 6,
859                 local_timeout,
860                 false,
861             ),
862             // Unreachable + fast primary, without fallback.
863             (
864                 &[unreachable_ipv4_addr(), local_ipv4_addr()][..],
865                 4,
866                 unreachable_v4_timeout,
867                 false,
868             ),
869             (
870                 &[unreachable_ipv6_addr(), local_ipv6_addr()][..],
871                 6,
872                 unreachable_v6_timeout,
873                 false,
874             ),
875             // Unreachable + fast primary, with (unused) fallback.
876             (
877                 &[
878                     unreachable_ipv4_addr(),
879                     local_ipv4_addr(),
880                     local_ipv6_addr(),
881                 ][..],
882                 4,
883                 unreachable_v4_timeout,
884                 false,
885             ),
886             (
887                 &[
888                     unreachable_ipv6_addr(),
889                     local_ipv6_addr(),
890                     local_ipv4_addr(),
891                 ][..],
892                 6,
893                 unreachable_v6_timeout,
894                 true,
895             ),
896             // Slow primary, with (used) fallback.
897             (
898                 &[slow_ipv4_addr(), local_ipv4_addr(), local_ipv6_addr()][..],
899                 6,
900                 fallback_timeout,
901                 false,
902             ),
903             (
904                 &[slow_ipv6_addr(), local_ipv6_addr(), local_ipv4_addr()][..],
905                 4,
906                 fallback_timeout,
907                 true,
908             ),
909             // Slow primary, with (used) unreachable + fast fallback.
910             (
911                 &[slow_ipv4_addr(), unreachable_ipv6_addr(), local_ipv6_addr()][..],
912                 6,
913                 fallback_timeout + unreachable_v6_timeout,
914                 false,
915             ),
916             (
917                 &[slow_ipv6_addr(), unreachable_ipv4_addr(), local_ipv4_addr()][..],
918                 4,
919                 fallback_timeout + unreachable_v4_timeout,
920                 true,
921             ),
922         ];
923 
924         // Scenarios for IPv6 -> IPv4 fallback require that host can access IPv6 network.
925         // Otherwise, connection to "slow" IPv6 address will error-out immediately.
926         let ipv6_accessible = measure_connect(slow_ipv6_addr()).0;
927 
928         for &(hosts, family, timeout, needs_ipv6_access) in scenarios {
929             if needs_ipv6_access && !ipv6_accessible {
930                 continue;
931             }
932 
933             let (start, stream) = rt
934                 .block_on(async move {
935                     let addrs = hosts
936                         .iter()
937                         .map(|host| (host.clone(), addr.port()).into())
938                         .collect();
939                     let cfg = Config {
940                         local_address_ipv4: None,
941                         local_address_ipv6: None,
942                         connect_timeout: None,
943                         keep_alive_timeout: None,
944                         happy_eyeballs_timeout: Some(fallback_timeout),
945                         nodelay: false,
946                         reuse_address: false,
947                         enforce_http: false,
948                         send_buffer_size: None,
949                         recv_buffer_size: None,
950                     };
951                     let connecting_tcp = ConnectingTcp::new(dns::SocketAddrs::new(addrs), &cfg);
952                     let start = Instant::now();
953                     Ok::<_, ConnectError>((start, ConnectingTcp::connect(connecting_tcp).await?))
954                 })
955                 .unwrap();
956             let res = if stream.peer_addr().unwrap().is_ipv4() {
957                 4
958             } else {
959                 6
960             };
961             let duration = start.elapsed();
962 
963             // Allow actual duration to be +/- 150ms off.
964             let min_duration = if timeout >= Duration::from_millis(150) {
965                 timeout - Duration::from_millis(150)
966             } else {
967                 Duration::default()
968             };
969             let max_duration = timeout + Duration::from_millis(150);
970 
971             assert_eq!(res, family);
972             assert!(duration >= min_duration);
973             assert!(duration <= max_duration);
974         }
975 
976         fn local_ipv4_addr() -> IpAddr {
977             Ipv4Addr::new(127, 0, 0, 1).into()
978         }
979 
980         fn local_ipv6_addr() -> IpAddr {
981             Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1).into()
982         }
983 
984         fn unreachable_ipv4_addr() -> IpAddr {
985             Ipv4Addr::new(127, 0, 0, 2).into()
986         }
987 
988         fn unreachable_ipv6_addr() -> IpAddr {
989             Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 2).into()
990         }
991 
992         fn slow_ipv4_addr() -> IpAddr {
993             // RFC 6890 reserved IPv4 address.
994             Ipv4Addr::new(198, 18, 0, 25).into()
995         }
996 
997         fn slow_ipv6_addr() -> IpAddr {
998             // RFC 6890 reserved IPv6 address.
999             Ipv6Addr::new(2001, 2, 0, 0, 0, 0, 0, 254).into()
1000         }
1001 
1002         fn measure_connect(addr: IpAddr) -> (bool, Duration) {
1003             let start = Instant::now();
1004             let result =
1005                 std::net::TcpStream::connect_timeout(&(addr, 80).into(), Duration::from_secs(1));
1006 
1007             let reachable = result.is_ok() || result.unwrap_err().kind() == io::ErrorKind::TimedOut;
1008             let duration = start.elapsed();
1009             (reachable, duration)
1010         }
1011     }
1012 }
1013