1 use socket2::TcpKeepalive;
2 use std::fmt;
3 use std::future::Future;
4 use std::io;
5 use std::net::{SocketAddr, TcpListener as StdTcpListener};
6 use std::pin::Pin;
7 use std::task::{Context, Poll};
8 use std::time::Duration;
9
10 use tokio::net::TcpListener;
11 use tokio::time::Sleep;
12 use tracing::{debug, error, trace};
13
14 #[allow(unreachable_pub)] // https://github.com/rust-lang/rust/issues/57411
15 pub use self::addr_stream::AddrStream;
16 use super::accept::Accept;
17
/// TCP keepalive settings applied to each accepted connection.
///
/// Every field is optional; a field left as `None` is simply not configured.
#[derive(Clone, Copy, Debug, Default)]
struct TcpKeepaliveConfig {
    /// Idle duration before keepalive probes start being sent.
    time: Option<Duration>,
    /// Duration between two successive keepalive probes.
    interval: Option<Duration>,
    /// Number of retransmissions before the remote end is declared unavailable.
    retries: Option<u32>,
}
24
25 impl TcpKeepaliveConfig {
26 /// Converts into a `socket2::TcpKeealive` if there is any keep alive configuration.
into_socket2(self) -> Option<TcpKeepalive>27 fn into_socket2(self) -> Option<TcpKeepalive> {
28 let mut dirty = false;
29 let mut ka = TcpKeepalive::new();
30 if let Some(time) = self.time {
31 ka = ka.with_time(time);
32 dirty = true
33 }
34 if let Some(interval) = self.interval {
35 ka = Self::ka_with_interval(ka, interval, &mut dirty)
36 };
37 if let Some(retries) = self.retries {
38 ka = Self::ka_with_retries(ka, retries, &mut dirty)
39 };
40 if dirty {
41 Some(ka)
42 } else {
43 None
44 }
45 }
46
47 #[cfg(any(
48 target_os = "android",
49 target_os = "dragonfly",
50 target_os = "freebsd",
51 target_os = "fuchsia",
52 target_os = "illumos",
53 target_os = "linux",
54 target_os = "netbsd",
55 target_vendor = "apple",
56 windows,
57 ))]
ka_with_interval(ka: TcpKeepalive, interval: Duration, dirty: &mut bool) -> TcpKeepalive58 fn ka_with_interval(ka: TcpKeepalive, interval: Duration, dirty: &mut bool) -> TcpKeepalive {
59 *dirty = true;
60 ka.with_interval(interval)
61 }
62
63 #[cfg(not(any(
64 target_os = "android",
65 target_os = "dragonfly",
66 target_os = "freebsd",
67 target_os = "fuchsia",
68 target_os = "illumos",
69 target_os = "linux",
70 target_os = "netbsd",
71 target_vendor = "apple",
72 windows,
73 )))]
ka_with_interval(ka: TcpKeepalive, _: Duration, _: &mut bool) -> TcpKeepalive74 fn ka_with_interval(ka: TcpKeepalive, _: Duration, _: &mut bool) -> TcpKeepalive {
75 ka // no-op as keepalive interval is not supported on this platform
76 }
77
78 #[cfg(any(
79 target_os = "android",
80 target_os = "dragonfly",
81 target_os = "freebsd",
82 target_os = "fuchsia",
83 target_os = "illumos",
84 target_os = "linux",
85 target_os = "netbsd",
86 target_vendor = "apple",
87 ))]
ka_with_retries(ka: TcpKeepalive, retries: u32, dirty: &mut bool) -> TcpKeepalive88 fn ka_with_retries(ka: TcpKeepalive, retries: u32, dirty: &mut bool) -> TcpKeepalive {
89 *dirty = true;
90 ka.with_retries(retries)
91 }
92
93 #[cfg(not(any(
94 target_os = "android",
95 target_os = "dragonfly",
96 target_os = "freebsd",
97 target_os = "fuchsia",
98 target_os = "illumos",
99 target_os = "linux",
100 target_os = "netbsd",
101 target_vendor = "apple",
102 )))]
ka_with_retries(ka: TcpKeepalive, _: u32, _: &mut bool) -> TcpKeepalive103 fn ka_with_retries(ka: TcpKeepalive, _: u32, _: &mut bool) -> TcpKeepalive {
104 ka // no-op as keepalive retries is not supported on this platform
105 }
106 }
107
108 /// A stream of connections from binding to an address.
109 #[must_use = "streams do nothing unless polled"]
110 pub struct AddrIncoming {
111 addr: SocketAddr,
112 listener: TcpListener,
113 sleep_on_errors: bool,
114 tcp_keepalive_config: TcpKeepaliveConfig,
115 tcp_nodelay: bool,
116 timeout: Option<Pin<Box<Sleep>>>,
117 }
118
119 impl AddrIncoming {
new(addr: &SocketAddr) -> crate::Result<Self>120 pub(super) fn new(addr: &SocketAddr) -> crate::Result<Self> {
121 let std_listener = StdTcpListener::bind(addr).map_err(crate::Error::new_listen)?;
122
123 AddrIncoming::from_std(std_listener)
124 }
125
from_std(std_listener: StdTcpListener) -> crate::Result<Self>126 pub(super) fn from_std(std_listener: StdTcpListener) -> crate::Result<Self> {
127 // TcpListener::from_std doesn't set O_NONBLOCK
128 std_listener
129 .set_nonblocking(true)
130 .map_err(crate::Error::new_listen)?;
131 let listener = TcpListener::from_std(std_listener).map_err(crate::Error::new_listen)?;
132 AddrIncoming::from_listener(listener)
133 }
134
135 /// Creates a new `AddrIncoming` binding to provided socket address.
bind(addr: &SocketAddr) -> crate::Result<Self>136 pub fn bind(addr: &SocketAddr) -> crate::Result<Self> {
137 AddrIncoming::new(addr)
138 }
139
140 /// Creates a new `AddrIncoming` from an existing `tokio::net::TcpListener`.
from_listener(listener: TcpListener) -> crate::Result<Self>141 pub fn from_listener(listener: TcpListener) -> crate::Result<Self> {
142 let addr = listener.local_addr().map_err(crate::Error::new_listen)?;
143 Ok(AddrIncoming {
144 listener,
145 addr,
146 sleep_on_errors: true,
147 tcp_keepalive_config: TcpKeepaliveConfig::default(),
148 tcp_nodelay: false,
149 timeout: None,
150 })
151 }
152
153 /// Get the local address bound to this listener.
local_addr(&self) -> SocketAddr154 pub fn local_addr(&self) -> SocketAddr {
155 self.addr
156 }
157
158 /// Set the duration to remain idle before sending TCP keepalive probes.
159 ///
160 /// If `None` is specified, keepalive is disabled.
set_keepalive(&mut self, time: Option<Duration>) -> &mut Self161 pub fn set_keepalive(&mut self, time: Option<Duration>) -> &mut Self {
162 self.tcp_keepalive_config.time = time;
163 self
164 }
165
166 /// Set the duration between two successive TCP keepalive retransmissions,
167 /// if acknowledgement to the previous keepalive transmission is not received.
set_keepalive_interval(&mut self, interval: Option<Duration>) -> &mut Self168 pub fn set_keepalive_interval(&mut self, interval: Option<Duration>) -> &mut Self {
169 self.tcp_keepalive_config.interval = interval;
170 self
171 }
172
173 /// Set the number of retransmissions to be carried out before declaring that remote end is not available.
set_keepalive_retries(&mut self, retries: Option<u32>) -> &mut Self174 pub fn set_keepalive_retries(&mut self, retries: Option<u32>) -> &mut Self {
175 self.tcp_keepalive_config.retries = retries;
176 self
177 }
178
179 /// Set the value of `TCP_NODELAY` option for accepted connections.
set_nodelay(&mut self, enabled: bool) -> &mut Self180 pub fn set_nodelay(&mut self, enabled: bool) -> &mut Self {
181 self.tcp_nodelay = enabled;
182 self
183 }
184
185 /// Set whether to sleep on accept errors.
186 ///
187 /// A possible scenario is that the process has hit the max open files
188 /// allowed, and so trying to accept a new connection will fail with
189 /// `EMFILE`. In some cases, it's preferable to just wait for some time, if
190 /// the application will likely close some files (or connections), and try
191 /// to accept the connection again. If this option is `true`, the error
192 /// will be logged at the `error` level, since it is still a big deal,
193 /// and then the listener will sleep for 1 second.
194 ///
195 /// In other cases, hitting the max open files should be treat similarly
196 /// to being out-of-memory, and simply error (and shutdown). Setting
197 /// this option to `false` will allow that.
198 ///
199 /// Default is `true`.
set_sleep_on_errors(&mut self, val: bool)200 pub fn set_sleep_on_errors(&mut self, val: bool) {
201 self.sleep_on_errors = val;
202 }
203
poll_next_(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<AddrStream>>204 fn poll_next_(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<AddrStream>> {
205 // Check if a previous timeout is active that was set by IO errors.
206 if let Some(ref mut to) = self.timeout {
207 ready!(Pin::new(to).poll(cx));
208 }
209 self.timeout = None;
210
211 loop {
212 match ready!(self.listener.poll_accept(cx)) {
213 Ok((socket, remote_addr)) => {
214 if let Some(tcp_keepalive) = &self.tcp_keepalive_config.into_socket2() {
215 let sock_ref = socket2::SockRef::from(&socket);
216 if let Err(e) = sock_ref.set_tcp_keepalive(tcp_keepalive) {
217 trace!("error trying to set TCP keepalive: {}", e);
218 }
219 }
220 if let Err(e) = socket.set_nodelay(self.tcp_nodelay) {
221 trace!("error trying to set TCP nodelay: {}", e);
222 }
223 let local_addr = socket.local_addr()?;
224 return Poll::Ready(Ok(AddrStream::new(socket, remote_addr, local_addr)));
225 }
226 Err(e) => {
227 // Connection errors can be ignored directly, continue by
228 // accepting the next request.
229 if is_connection_error(&e) {
230 debug!("accepted connection already errored: {}", e);
231 continue;
232 }
233
234 if self.sleep_on_errors {
235 error!("accept error: {}", e);
236
237 // Sleep 1s.
238 let mut timeout = Box::pin(tokio::time::sleep(Duration::from_secs(1)));
239
240 match timeout.as_mut().poll(cx) {
241 Poll::Ready(()) => {
242 // Wow, it's been a second already? Ok then...
243 continue;
244 }
245 Poll::Pending => {
246 self.timeout = Some(timeout);
247 return Poll::Pending;
248 }
249 }
250 } else {
251 return Poll::Ready(Err(e));
252 }
253 }
254 }
255 }
256 }
257 }
258
259 impl Accept for AddrIncoming {
260 type Conn = AddrStream;
261 type Error = io::Error;
262
poll_accept( mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll<Option<Result<Self::Conn, Self::Error>>>263 fn poll_accept(
264 mut self: Pin<&mut Self>,
265 cx: &mut Context<'_>,
266 ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
267 let result = ready!(self.poll_next_(cx));
268 Poll::Ready(Some(result))
269 }
270 }
271
/// Reports whether an `accept()` error is tied to a single connection.
///
/// If so, the failed connection can simply be skipped, because the next
/// connection might already be ready to be accepted.
///
/// Every other error incurs a timeout before the next `accept()` is performed,
/// when sleeping on errors is enabled. That timeout keeps resource-exhaustion
/// errors like `ENFILE` and `EMFILE` from driving the accept loop into a
/// tight spin.
fn is_connection_error(e: &io::Error) -> bool {
    const PER_CONNECTION_KINDS: [io::ErrorKind; 3] = [
        io::ErrorKind::ConnectionRefused,
        io::ErrorKind::ConnectionAborted,
        io::ErrorKind::ConnectionReset,
    ];

    PER_CONNECTION_KINDS.contains(&e.kind())
}
287
288 impl fmt::Debug for AddrIncoming {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result289 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
290 f.debug_struct("AddrIncoming")
291 .field("addr", &self.addr)
292 .field("sleep_on_errors", &self.sleep_on_errors)
293 .field("tcp_keepalive_config", &self.tcp_keepalive_config)
294 .field("tcp_nodelay", &self.tcp_nodelay)
295 .finish()
296 }
297 }
298
299 mod addr_stream {
300 use std::io;
301 use std::net::SocketAddr;
302 #[cfg(unix)]
303 use std::os::unix::io::{AsRawFd, RawFd};
304 use std::pin::Pin;
305 use std::task::{Context, Poll};
306 use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
307 use tokio::net::TcpStream;
308
309 pin_project_lite::pin_project! {
310 /// A transport returned yieled by `AddrIncoming`.
311 #[derive(Debug)]
312 pub struct AddrStream {
313 #[pin]
314 inner: TcpStream,
315 pub(super) remote_addr: SocketAddr,
316 pub(super) local_addr: SocketAddr
317 }
318 }
319
320 impl AddrStream {
new( tcp: TcpStream, remote_addr: SocketAddr, local_addr: SocketAddr, ) -> AddrStream321 pub(super) fn new(
322 tcp: TcpStream,
323 remote_addr: SocketAddr,
324 local_addr: SocketAddr,
325 ) -> AddrStream {
326 AddrStream {
327 inner: tcp,
328 remote_addr,
329 local_addr,
330 }
331 }
332
333 /// Returns the remote (peer) address of this connection.
334 #[inline]
remote_addr(&self) -> SocketAddr335 pub fn remote_addr(&self) -> SocketAddr {
336 self.remote_addr
337 }
338
339 /// Returns the local address of this connection.
340 #[inline]
local_addr(&self) -> SocketAddr341 pub fn local_addr(&self) -> SocketAddr {
342 self.local_addr
343 }
344
345 /// Consumes the AddrStream and returns the underlying IO object
346 #[inline]
into_inner(self) -> TcpStream347 pub fn into_inner(self) -> TcpStream {
348 self.inner
349 }
350
351 /// Attempt to receive data on the socket, without removing that data
352 /// from the queue, registering the current task for wakeup if data is
353 /// not yet available.
poll_peek( &mut self, cx: &mut Context<'_>, buf: &mut tokio::io::ReadBuf<'_>, ) -> Poll<io::Result<usize>>354 pub fn poll_peek(
355 &mut self,
356 cx: &mut Context<'_>,
357 buf: &mut tokio::io::ReadBuf<'_>,
358 ) -> Poll<io::Result<usize>> {
359 self.inner.poll_peek(cx, buf)
360 }
361 }
362
363 impl AsyncRead for AddrStream {
364 #[inline]
poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll<io::Result<()>>365 fn poll_read(
366 self: Pin<&mut Self>,
367 cx: &mut Context<'_>,
368 buf: &mut ReadBuf<'_>,
369 ) -> Poll<io::Result<()>> {
370 self.project().inner.poll_read(cx, buf)
371 }
372 }
373
374 impl AsyncWrite for AddrStream {
375 #[inline]
poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll<io::Result<usize>>376 fn poll_write(
377 self: Pin<&mut Self>,
378 cx: &mut Context<'_>,
379 buf: &[u8],
380 ) -> Poll<io::Result<usize>> {
381 self.project().inner.poll_write(cx, buf)
382 }
383
384 #[inline]
poll_write_vectored( self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[io::IoSlice<'_>], ) -> Poll<io::Result<usize>>385 fn poll_write_vectored(
386 self: Pin<&mut Self>,
387 cx: &mut Context<'_>,
388 bufs: &[io::IoSlice<'_>],
389 ) -> Poll<io::Result<usize>> {
390 self.project().inner.poll_write_vectored(cx, bufs)
391 }
392
393 #[inline]
poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>>394 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
395 // TCP flush is a noop
396 Poll::Ready(Ok(()))
397 }
398
399 #[inline]
poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>400 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
401 self.project().inner.poll_shutdown(cx)
402 }
403
404 #[inline]
is_write_vectored(&self) -> bool405 fn is_write_vectored(&self) -> bool {
406 // Note that since `self.inner` is a `TcpStream`, this could
407 // *probably* be hard-coded to return `true`...but it seems more
408 // correct to ask it anyway (maybe we're on some platform without
409 // scatter-gather IO?)
410 self.inner.is_write_vectored()
411 }
412 }
413
414 #[cfg(unix)]
415 impl AsRawFd for AddrStream {
as_raw_fd(&self) -> RawFd416 fn as_raw_fd(&self) -> RawFd {
417 self.inner.as_raw_fd()
418 }
419 }
420 }
421
#[cfg(test)]
mod tests {
    use crate::server::tcp::TcpKeepaliveConfig;
    use std::time::Duration;

    /// Converts `config` into a `socket2::TcpKeepalive` and returns its
    /// `Debug` rendering. Panics with "test failed" when nothing would be
    /// configured.
    fn debug_keepalive(config: TcpKeepaliveConfig) -> String {
        let keepalive = config.into_socket2().expect("test failed");
        format!("{keepalive:?}")
    }

    #[test]
    fn no_tcp_keepalive_config() {
        let config = TcpKeepaliveConfig::default();
        assert!(config.into_socket2().is_none());
    }

    #[test]
    fn tcp_keepalive_time_config() {
        let config = TcpKeepaliveConfig {
            time: Some(Duration::from_secs(60)),
            ..TcpKeepaliveConfig::default()
        };
        assert!(debug_keepalive(config).contains("time: Some(60s)"));
    }

    #[cfg(any(
        target_os = "android",
        target_os = "dragonfly",
        target_os = "freebsd",
        target_os = "fuchsia",
        target_os = "illumos",
        target_os = "linux",
        target_os = "netbsd",
        target_vendor = "apple",
        windows,
    ))]
    #[test]
    fn tcp_keepalive_interval_config() {
        let config = TcpKeepaliveConfig {
            interval: Some(Duration::from_secs(1)),
            ..TcpKeepaliveConfig::default()
        };
        assert!(debug_keepalive(config).contains("interval: Some(1s)"));
    }

    #[cfg(any(
        target_os = "android",
        target_os = "dragonfly",
        target_os = "freebsd",
        target_os = "fuchsia",
        target_os = "illumos",
        target_os = "linux",
        target_os = "netbsd",
        target_vendor = "apple",
    ))]
    #[test]
    fn tcp_keepalive_retries_config() {
        let config = TcpKeepaliveConfig {
            retries: Some(3),
            ..TcpKeepaliveConfig::default()
        };
        assert!(debug_keepalive(config).contains("retries: Some(3)"));
    }
}
486