1 use socket2::TcpKeepalive;
2 use std::fmt;
3 use std::future::Future;
4 use std::io;
5 use std::net::{SocketAddr, TcpListener as StdTcpListener};
6 use std::pin::Pin;
7 use std::task::{Context, Poll};
8 use std::time::Duration;
9 
10 use tokio::net::TcpListener;
11 use tokio::time::Sleep;
12 use tracing::{debug, error, trace};
13 
14 #[allow(unreachable_pub)] // https://github.com/rust-lang/rust/issues/57411
15 pub use self::addr_stream::AddrStream;
16 use super::accept::Accept;
17 
/// TCP keepalive options applied to each accepted connection.
///
/// Every field is optional; when all of them are `None` the configuration
/// converts to no `socket2::TcpKeepalive` at all.
#[derive(Default, Debug, Clone, Copy)]
struct TcpKeepaliveConfig {
    /// Duration to remain idle before sending TCP keepalive probes.
    time: Option<Duration>,
    /// Duration between two successive keepalive retransmissions when the
    /// previous one was not acknowledged.
    interval: Option<Duration>,
    /// Number of retransmissions carried out before the remote end is
    /// declared unavailable.
    retries: Option<u32>,
}
24 
25 impl TcpKeepaliveConfig {
26     /// Converts into a `socket2::TcpKeealive` if there is any keep alive configuration.
into_socket2(self) -> Option<TcpKeepalive>27     fn into_socket2(self) -> Option<TcpKeepalive> {
28         let mut dirty = false;
29         let mut ka = TcpKeepalive::new();
30         if let Some(time) = self.time {
31             ka = ka.with_time(time);
32             dirty = true
33         }
34         if let Some(interval) = self.interval {
35             ka = Self::ka_with_interval(ka, interval, &mut dirty)
36         };
37         if let Some(retries) = self.retries {
38             ka = Self::ka_with_retries(ka, retries, &mut dirty)
39         };
40         if dirty {
41             Some(ka)
42         } else {
43             None
44         }
45     }
46 
47     #[cfg(any(
48         target_os = "android",
49         target_os = "dragonfly",
50         target_os = "freebsd",
51         target_os = "fuchsia",
52         target_os = "illumos",
53         target_os = "linux",
54         target_os = "netbsd",
55         target_vendor = "apple",
56         windows,
57     ))]
ka_with_interval(ka: TcpKeepalive, interval: Duration, dirty: &mut bool) -> TcpKeepalive58     fn ka_with_interval(ka: TcpKeepalive, interval: Duration, dirty: &mut bool) -> TcpKeepalive {
59         *dirty = true;
60         ka.with_interval(interval)
61     }
62 
63     #[cfg(not(any(
64         target_os = "android",
65         target_os = "dragonfly",
66         target_os = "freebsd",
67         target_os = "fuchsia",
68         target_os = "illumos",
69         target_os = "linux",
70         target_os = "netbsd",
71         target_vendor = "apple",
72         windows,
73     )))]
ka_with_interval(ka: TcpKeepalive, _: Duration, _: &mut bool) -> TcpKeepalive74     fn ka_with_interval(ka: TcpKeepalive, _: Duration, _: &mut bool) -> TcpKeepalive {
75         ka // no-op as keepalive interval is not supported on this platform
76     }
77 
78     #[cfg(any(
79         target_os = "android",
80         target_os = "dragonfly",
81         target_os = "freebsd",
82         target_os = "fuchsia",
83         target_os = "illumos",
84         target_os = "linux",
85         target_os = "netbsd",
86         target_vendor = "apple",
87     ))]
ka_with_retries(ka: TcpKeepalive, retries: u32, dirty: &mut bool) -> TcpKeepalive88     fn ka_with_retries(ka: TcpKeepalive, retries: u32, dirty: &mut bool) -> TcpKeepalive {
89         *dirty = true;
90         ka.with_retries(retries)
91     }
92 
93     #[cfg(not(any(
94         target_os = "android",
95         target_os = "dragonfly",
96         target_os = "freebsd",
97         target_os = "fuchsia",
98         target_os = "illumos",
99         target_os = "linux",
100         target_os = "netbsd",
101         target_vendor = "apple",
102     )))]
ka_with_retries(ka: TcpKeepalive, _: u32, _: &mut bool) -> TcpKeepalive103     fn ka_with_retries(ka: TcpKeepalive, _: u32, _: &mut bool) -> TcpKeepalive {
104         ka // no-op as keepalive retries is not supported on this platform
105     }
106 }
107 
/// A stream of connections from binding to an address.
#[must_use = "streams do nothing unless polled"]
pub struct AddrIncoming {
    // Local address reported by the listener when this value was created.
    addr: SocketAddr,
    // Listening socket that new connections are accepted from.
    listener: TcpListener,
    // Whether a non-connection accept error pauses accepting instead of
    // being returned to the caller.
    sleep_on_errors: bool,
    // Keepalive options applied to every accepted socket.
    tcp_keepalive_config: TcpKeepaliveConfig,
    // Value of `TCP_NODELAY` set on every accepted socket.
    tcp_nodelay: bool,
    // Pending sleep armed after an accept error; accepting resumes once it
    // completes.
    timeout: Option<Pin<Box<Sleep>>>,
}
118 
119 impl AddrIncoming {
new(addr: &SocketAddr) -> crate::Result<Self>120     pub(super) fn new(addr: &SocketAddr) -> crate::Result<Self> {
121         let std_listener = StdTcpListener::bind(addr).map_err(crate::Error::new_listen)?;
122 
123         AddrIncoming::from_std(std_listener)
124     }
125 
from_std(std_listener: StdTcpListener) -> crate::Result<Self>126     pub(super) fn from_std(std_listener: StdTcpListener) -> crate::Result<Self> {
127         // TcpListener::from_std doesn't set O_NONBLOCK
128         std_listener
129             .set_nonblocking(true)
130             .map_err(crate::Error::new_listen)?;
131         let listener = TcpListener::from_std(std_listener).map_err(crate::Error::new_listen)?;
132         AddrIncoming::from_listener(listener)
133     }
134 
135     /// Creates a new `AddrIncoming` binding to provided socket address.
bind(addr: &SocketAddr) -> crate::Result<Self>136     pub fn bind(addr: &SocketAddr) -> crate::Result<Self> {
137         AddrIncoming::new(addr)
138     }
139 
140     /// Creates a new `AddrIncoming` from an existing `tokio::net::TcpListener`.
from_listener(listener: TcpListener) -> crate::Result<Self>141     pub fn from_listener(listener: TcpListener) -> crate::Result<Self> {
142         let addr = listener.local_addr().map_err(crate::Error::new_listen)?;
143         Ok(AddrIncoming {
144             listener,
145             addr,
146             sleep_on_errors: true,
147             tcp_keepalive_config: TcpKeepaliveConfig::default(),
148             tcp_nodelay: false,
149             timeout: None,
150         })
151     }
152 
153     /// Get the local address bound to this listener.
local_addr(&self) -> SocketAddr154     pub fn local_addr(&self) -> SocketAddr {
155         self.addr
156     }
157 
158     /// Set the duration to remain idle before sending TCP keepalive probes.
159     ///
160     /// If `None` is specified, keepalive is disabled.
set_keepalive(&mut self, time: Option<Duration>) -> &mut Self161     pub fn set_keepalive(&mut self, time: Option<Duration>) -> &mut Self {
162         self.tcp_keepalive_config.time = time;
163         self
164     }
165 
166     /// Set the duration between two successive TCP keepalive retransmissions,
167     /// if acknowledgement to the previous keepalive transmission is not received.
set_keepalive_interval(&mut self, interval: Option<Duration>) -> &mut Self168     pub fn set_keepalive_interval(&mut self, interval: Option<Duration>) -> &mut Self {
169         self.tcp_keepalive_config.interval = interval;
170         self
171     }
172 
173     /// Set the number of retransmissions to be carried out before declaring that remote end is not available.
set_keepalive_retries(&mut self, retries: Option<u32>) -> &mut Self174     pub fn set_keepalive_retries(&mut self, retries: Option<u32>) -> &mut Self {
175         self.tcp_keepalive_config.retries = retries;
176         self
177     }
178 
179     /// Set the value of `TCP_NODELAY` option for accepted connections.
set_nodelay(&mut self, enabled: bool) -> &mut Self180     pub fn set_nodelay(&mut self, enabled: bool) -> &mut Self {
181         self.tcp_nodelay = enabled;
182         self
183     }
184 
185     /// Set whether to sleep on accept errors.
186     ///
187     /// A possible scenario is that the process has hit the max open files
188     /// allowed, and so trying to accept a new connection will fail with
189     /// `EMFILE`. In some cases, it's preferable to just wait for some time, if
190     /// the application will likely close some files (or connections), and try
191     /// to accept the connection again. If this option is `true`, the error
192     /// will be logged at the `error` level, since it is still a big deal,
193     /// and then the listener will sleep for 1 second.
194     ///
195     /// In other cases, hitting the max open files should be treat similarly
196     /// to being out-of-memory, and simply error (and shutdown). Setting
197     /// this option to `false` will allow that.
198     ///
199     /// Default is `true`.
set_sleep_on_errors(&mut self, val: bool)200     pub fn set_sleep_on_errors(&mut self, val: bool) {
201         self.sleep_on_errors = val;
202     }
203 
poll_next_(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<AddrStream>>204     fn poll_next_(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<AddrStream>> {
205         // Check if a previous timeout is active that was set by IO errors.
206         if let Some(ref mut to) = self.timeout {
207             ready!(Pin::new(to).poll(cx));
208         }
209         self.timeout = None;
210 
211         loop {
212             match ready!(self.listener.poll_accept(cx)) {
213                 Ok((socket, remote_addr)) => {
214                     if let Some(tcp_keepalive) = &self.tcp_keepalive_config.into_socket2() {
215                         let sock_ref = socket2::SockRef::from(&socket);
216                         if let Err(e) = sock_ref.set_tcp_keepalive(tcp_keepalive) {
217                             trace!("error trying to set TCP keepalive: {}", e);
218                         }
219                     }
220                     if let Err(e) = socket.set_nodelay(self.tcp_nodelay) {
221                         trace!("error trying to set TCP nodelay: {}", e);
222                     }
223                     let local_addr = socket.local_addr()?;
224                     return Poll::Ready(Ok(AddrStream::new(socket, remote_addr, local_addr)));
225                 }
226                 Err(e) => {
227                     // Connection errors can be ignored directly, continue by
228                     // accepting the next request.
229                     if is_connection_error(&e) {
230                         debug!("accepted connection already errored: {}", e);
231                         continue;
232                     }
233 
234                     if self.sleep_on_errors {
235                         error!("accept error: {}", e);
236 
237                         // Sleep 1s.
238                         let mut timeout = Box::pin(tokio::time::sleep(Duration::from_secs(1)));
239 
240                         match timeout.as_mut().poll(cx) {
241                             Poll::Ready(()) => {
242                                 // Wow, it's been a second already? Ok then...
243                                 continue;
244                             }
245                             Poll::Pending => {
246                                 self.timeout = Some(timeout);
247                                 return Poll::Pending;
248                             }
249                         }
250                     } else {
251                         return Poll::Ready(Err(e));
252                     }
253                 }
254             }
255         }
256     }
257 }
258 
259 impl Accept for AddrIncoming {
260     type Conn = AddrStream;
261     type Error = io::Error;
262 
poll_accept( mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll<Option<Result<Self::Conn, Self::Error>>>263     fn poll_accept(
264         mut self: Pin<&mut Self>,
265         cx: &mut Context<'_>,
266     ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
267         let result = ready!(self.poll_next_(cx));
268         Poll::Ready(Some(result))
269     }
270 }
271 
/// Tells whether an `accept()` error belongs to a single connection.
///
/// Such an error says nothing about the listener itself, so the next
/// connection might be ready to be accepted right away.
///
/// Every other error incurs a timeout before `accept()` is tried again. The
/// timeout helps with resource exhaustion errors such as ENFILE and EMFILE,
/// which could otherwise send the accept loop into a tight spin.
fn is_connection_error(e: &io::Error) -> bool {
    const PER_CONNECTION: [io::ErrorKind; 3] = [
        io::ErrorKind::ConnectionRefused,
        io::ErrorKind::ConnectionAborted,
        io::ErrorKind::ConnectionReset,
    ];

    PER_CONNECTION.contains(&e.kind())
}
287 
288 impl fmt::Debug for AddrIncoming {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result289     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
290         f.debug_struct("AddrIncoming")
291             .field("addr", &self.addr)
292             .field("sleep_on_errors", &self.sleep_on_errors)
293             .field("tcp_keepalive_config", &self.tcp_keepalive_config)
294             .field("tcp_nodelay", &self.tcp_nodelay)
295             .finish()
296     }
297 }
298 
299 mod addr_stream {
300     use std::io;
301     use std::net::SocketAddr;
302     #[cfg(unix)]
303     use std::os::unix::io::{AsRawFd, RawFd};
304     use std::pin::Pin;
305     use std::task::{Context, Poll};
306     use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
307     use tokio::net::TcpStream;
308 
309     pin_project_lite::pin_project! {
310         /// A transport returned yieled by `AddrIncoming`.
311         #[derive(Debug)]
312         pub struct AddrStream {
313             #[pin]
314             inner: TcpStream,
315             pub(super) remote_addr: SocketAddr,
316             pub(super) local_addr: SocketAddr
317         }
318     }
319 
320     impl AddrStream {
new( tcp: TcpStream, remote_addr: SocketAddr, local_addr: SocketAddr, ) -> AddrStream321         pub(super) fn new(
322             tcp: TcpStream,
323             remote_addr: SocketAddr,
324             local_addr: SocketAddr,
325         ) -> AddrStream {
326             AddrStream {
327                 inner: tcp,
328                 remote_addr,
329                 local_addr,
330             }
331         }
332 
333         /// Returns the remote (peer) address of this connection.
334         #[inline]
remote_addr(&self) -> SocketAddr335         pub fn remote_addr(&self) -> SocketAddr {
336             self.remote_addr
337         }
338 
339         /// Returns the local address of this connection.
340         #[inline]
local_addr(&self) -> SocketAddr341         pub fn local_addr(&self) -> SocketAddr {
342             self.local_addr
343         }
344 
345         /// Consumes the AddrStream and returns the underlying IO object
346         #[inline]
into_inner(self) -> TcpStream347         pub fn into_inner(self) -> TcpStream {
348             self.inner
349         }
350 
351         /// Attempt to receive data on the socket, without removing that data
352         /// from the queue, registering the current task for wakeup if data is
353         /// not yet available.
poll_peek( &mut self, cx: &mut Context<'_>, buf: &mut tokio::io::ReadBuf<'_>, ) -> Poll<io::Result<usize>>354         pub fn poll_peek(
355             &mut self,
356             cx: &mut Context<'_>,
357             buf: &mut tokio::io::ReadBuf<'_>,
358         ) -> Poll<io::Result<usize>> {
359             self.inner.poll_peek(cx, buf)
360         }
361     }
362 
363     impl AsyncRead for AddrStream {
364         #[inline]
poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<io::Result<()>>365         fn poll_read(
366             self: Pin<&mut Self>,
367             cx: &mut Context<'_>,
368             buf: &mut ReadBuf<'_>,
369         ) -> Poll<io::Result<()>> {
370             self.project().inner.poll_read(cx, buf)
371         }
372     }
373 
374     impl AsyncWrite for AddrStream {
375         #[inline]
poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll<io::Result<usize>>376         fn poll_write(
377             self: Pin<&mut Self>,
378             cx: &mut Context<'_>,
379             buf: &[u8],
380         ) -> Poll<io::Result<usize>> {
381             self.project().inner.poll_write(cx, buf)
382         }
383 
384         #[inline]
poll_write_vectored( self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[io::IoSlice<'_>], ) -> Poll<io::Result<usize>>385         fn poll_write_vectored(
386             self: Pin<&mut Self>,
387             cx: &mut Context<'_>,
388             bufs: &[io::IoSlice<'_>],
389         ) -> Poll<io::Result<usize>> {
390             self.project().inner.poll_write_vectored(cx, bufs)
391         }
392 
393         #[inline]
poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>>394         fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
395             // TCP flush is a noop
396             Poll::Ready(Ok(()))
397         }
398 
399         #[inline]
poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>400         fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
401             self.project().inner.poll_shutdown(cx)
402         }
403 
404         #[inline]
is_write_vectored(&self) -> bool405         fn is_write_vectored(&self) -> bool {
406             // Note that since `self.inner` is a `TcpStream`, this could
407             // *probably* be hard-coded to return `true`...but it seems more
408             // correct to ask it anyway (maybe we're on some platform without
409             // scatter-gather IO?)
410             self.inner.is_write_vectored()
411         }
412     }
413 
414     #[cfg(unix)]
415     impl AsRawFd for AddrStream {
as_raw_fd(&self) -> RawFd416         fn as_raw_fd(&self) -> RawFd {
417             self.inner.as_raw_fd()
418         }
419     }
420 }
421 
#[cfg(test)]
mod tests {
    use crate::server::tcp::TcpKeepaliveConfig;
    use std::time::Duration;

    // An untouched configuration must not produce any socket2 keepalive value.
    #[test]
    fn no_tcp_keepalive_config() {
        let converted = TcpKeepaliveConfig::default().into_socket2();
        assert!(converted.is_none());
    }

    // Setting only the idle time must carry over into the socket2 value.
    #[test]
    fn tcp_keepalive_time_config() {
        let kac = TcpKeepaliveConfig {
            time: Some(Duration::from_secs(60)),
            ..TcpKeepaliveConfig::default()
        };
        let tcp_keepalive = kac.into_socket2().expect("test failed");
        let rendered = format!("{tcp_keepalive:?}");
        assert!(rendered.contains("time: Some(60s)"));
    }

    // Only compiled for targets where the interval option is applied.
    #[cfg(any(
        target_os = "android",
        target_os = "dragonfly",
        target_os = "freebsd",
        target_os = "fuchsia",
        target_os = "illumos",
        target_os = "linux",
        target_os = "netbsd",
        target_vendor = "apple",
        windows,
    ))]
    #[test]
    fn tcp_keepalive_interval_config() {
        let kac = TcpKeepaliveConfig {
            interval: Some(Duration::from_secs(1)),
            ..TcpKeepaliveConfig::default()
        };
        let tcp_keepalive = kac.into_socket2().expect("test failed");
        let rendered = format!("{tcp_keepalive:?}");
        assert!(rendered.contains("interval: Some(1s)"));
    }

    // Only compiled for targets where the retries option is applied.
    #[cfg(any(
        target_os = "android",
        target_os = "dragonfly",
        target_os = "freebsd",
        target_os = "fuchsia",
        target_os = "illumos",
        target_os = "linux",
        target_os = "netbsd",
        target_vendor = "apple",
    ))]
    #[test]
    fn tcp_keepalive_retries_config() {
        let kac = TcpKeepaliveConfig {
            retries: Some(3),
            ..TcpKeepaliveConfig::default()
        };
        let tcp_keepalive = kac.into_socket2().expect("test failed");
        let rendered = format!("{tcp_keepalive:?}");
        assert!(rendered.contains("retries: Some(3)"));
    }
}
486