1 use std::fmt;
2 use std::io::{self, IoSlice, IoSliceMut, Read, Write};
3 use std::net::{self, Shutdown, SocketAddr};
4 #[cfg(any(unix, target_os = "wasi"))]
5 use std::os::fd::{AsFd, AsRawFd, BorrowedFd, FromRawFd, IntoRawFd, OwnedFd, RawFd};
6 // TODO: once <https://github.com/rust-lang/rust/issues/126198> is fixed this
7 // can use `std::os::fd` and be merged with the above.
8 #[cfg(target_os = "hermit")]
9 use std::os::hermit::io::{AsFd, AsRawFd, BorrowedFd, FromRawFd, IntoRawFd, OwnedFd, RawFd};
10 #[cfg(windows)]
11 use std::os::windows::io::{
12     AsRawSocket, AsSocket, BorrowedSocket, FromRawSocket, IntoRawSocket, OwnedSocket, RawSocket,
13 };
14 
15 use crate::io_source::IoSource;
16 #[cfg(not(target_os = "wasi"))]
17 use crate::sys::tcp::{connect, new_for_addr};
18 use crate::{event, Interest, Registry, Token};
19 
/// A non-blocking TCP stream between a local socket and a remote socket.
///
/// The socket will be closed when the value is dropped.
///
/// A stream is obtained either from [`TcpStream::connect`], which may return
/// before the connection is established (see its documentation for how to
/// detect completion), or by wrapping an existing standard library stream
/// with [`TcpStream::from_std`]. All reads and writes are delegated to the
/// wrapped standard library stream through Mio's `IoSource`.
///
/// # Examples
///
#[cfg_attr(feature = "os-poll", doc = "```")]
#[cfg_attr(not(feature = "os-poll"), doc = "```ignore")]
/// # use std::net::{TcpListener, SocketAddr};
/// # use std::error::Error;
/// #
/// # fn main() -> Result<(), Box<dyn Error>> {
/// let address: SocketAddr = "127.0.0.1:0".parse()?;
/// let listener = TcpListener::bind(address)?;
/// use mio::{Events, Interest, Poll, Token};
/// use mio::net::TcpStream;
/// use std::time::Duration;
///
/// let mut stream = TcpStream::connect(listener.local_addr()?)?;
///
/// let mut poll = Poll::new()?;
/// let mut events = Events::with_capacity(128);
///
/// // Register the socket with `Poll`
/// poll.registry().register(&mut stream, Token(0), Interest::WRITABLE)?;
///
/// poll.poll(&mut events, Some(Duration::from_millis(100)))?;
///
/// // The socket might be ready at this point
/// #     Ok(())
/// # }
/// ```
pub struct TcpStream {
    // The standard library stream, wrapped in `IoSource` so that every I/O
    // operation (including user-supplied ones via `try_io`) goes through
    // `IoSource::do_io`.
    inner: IoSource<net::TcpStream>,
}
55 
56 impl TcpStream {
57     /// Create a new TCP stream and issue a non-blocking connect to the
58     /// specified address.
59     ///
60     /// # Notes
61     ///
62     /// The returned `TcpStream` may not be connected (and thus usable), unlike
63     /// the API found in `std::net::TcpStream`. Because Mio issues a
64     /// *non-blocking* connect it will not block the thread and instead return
65     /// an unconnected `TcpStream`.
66     ///
67     /// Ensuring the returned stream is connected is surprisingly complex when
68     /// considering cross-platform support. Doing this properly should follow
69     /// the steps below, an example implementation can be found
70     /// [here](https://github.com/Thomasdezeeuw/heph/blob/0c4f1ab3eaf08bea1d65776528bfd6114c9f8374/src/net/tcp/stream.rs#L560-L622).
71     ///
72     ///  1. Call `TcpStream::connect`
73     ///  2. Register the returned stream with at least [write interest].
74     ///  3. Wait for a (writable) event.
75     ///  4. Check `TcpStream::take_error`. If it returns an error, then
76     ///     something went wrong. If it returns `Ok(None)`, then proceed to
77     ///     step 5.
78     ///  5. Check `TcpStream::peer_addr`. If it returns `libc::EINPROGRESS` or
79     ///     `ErrorKind::NotConnected` it means the stream is not yet connected,
80     ///     go back to step 3. If it returns an address it means the stream is
81     ///     connected, go to step 6. If another error is returned something
82     ///     went wrong.
83     ///  6. Now the stream can be used.
84     ///
85     /// This may return a `WouldBlock` in which case the socket connection
86     /// cannot be completed immediately, it usually means there are insufficient
87     /// entries in the routing cache.
88     ///
89     /// [write interest]: Interest::WRITABLE
90     #[cfg(not(target_os = "wasi"))]
connect(addr: SocketAddr) -> io::Result<TcpStream>91     pub fn connect(addr: SocketAddr) -> io::Result<TcpStream> {
92         let socket = new_for_addr(addr)?;
93         #[cfg(any(unix, target_os = "hermit"))]
94         let stream = unsafe { TcpStream::from_raw_fd(socket) };
95         #[cfg(windows)]
96         let stream = unsafe { TcpStream::from_raw_socket(socket as _) };
97         connect(&stream.inner, addr)?;
98         Ok(stream)
99     }
100 
101     /// Creates a new `TcpStream` from a standard `net::TcpStream`.
102     ///
103     /// This function is intended to be used to wrap a TCP stream from the
104     /// standard library in the Mio equivalent. The conversion assumes nothing
105     /// about the underlying stream; it is left up to the user to set it in
106     /// non-blocking mode.
107     ///
108     /// # Note
109     ///
110     /// The TCP stream here will not have `connect` called on it, so it
111     /// should already be connected via some other means (be it manually, or
112     /// the standard library).
from_std(stream: net::TcpStream) -> TcpStream113     pub fn from_std(stream: net::TcpStream) -> TcpStream {
114         TcpStream {
115             inner: IoSource::new(stream),
116         }
117     }
118 
119     /// Returns the socket address of the remote peer of this TCP connection.
peer_addr(&self) -> io::Result<SocketAddr>120     pub fn peer_addr(&self) -> io::Result<SocketAddr> {
121         self.inner.peer_addr()
122     }
123 
124     /// Returns the socket address of the local half of this TCP connection.
local_addr(&self) -> io::Result<SocketAddr>125     pub fn local_addr(&self) -> io::Result<SocketAddr> {
126         self.inner.local_addr()
127     }
128 
129     /// Shuts down the read, write, or both halves of this connection.
130     ///
131     /// This function will cause all pending and future I/O on the specified
132     /// portions to return immediately with an appropriate value (see the
133     /// documentation of `Shutdown`).
shutdown(&self, how: Shutdown) -> io::Result<()>134     pub fn shutdown(&self, how: Shutdown) -> io::Result<()> {
135         self.inner.shutdown(how)
136     }
137 
138     /// Sets the value of the `TCP_NODELAY` option on this socket.
139     ///
140     /// If set, this option disables the Nagle algorithm. This means that
141     /// segments are always sent as soon as possible, even if there is only a
142     /// small amount of data. When not set, data is buffered until there is a
143     /// sufficient amount to send out, thereby avoiding the frequent sending of
144     /// small packets.
145     ///
146     /// # Notes
147     ///
148     /// On Windows make sure the stream is connected before calling this method,
149     /// by receiving an (writable) event. Trying to set `nodelay` on an
150     /// unconnected `TcpStream` is unspecified behavior.
set_nodelay(&self, nodelay: bool) -> io::Result<()>151     pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
152         self.inner.set_nodelay(nodelay)
153     }
154 
155     /// Gets the value of the `TCP_NODELAY` option on this socket.
156     ///
157     /// For more information about this option, see [`set_nodelay`][link].
158     ///
159     /// [link]: #method.set_nodelay
160     ///
161     /// # Notes
162     ///
163     /// On Windows make sure the stream is connected before calling this method,
164     /// by receiving an (writable) event. Trying to get `nodelay` on an
165     /// unconnected `TcpStream` is unspecified behavior.
nodelay(&self) -> io::Result<bool>166     pub fn nodelay(&self) -> io::Result<bool> {
167         self.inner.nodelay()
168     }
169 
170     /// Sets the value for the `IP_TTL` option on this socket.
171     ///
172     /// This value sets the time-to-live field that is used in every packet sent
173     /// from this socket.
174     ///
175     /// # Notes
176     ///
177     /// On Windows make sure the stream is connected before calling this method,
178     /// by receiving an (writable) event. Trying to set `ttl` on an
179     /// unconnected `TcpStream` is unspecified behavior.
set_ttl(&self, ttl: u32) -> io::Result<()>180     pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
181         self.inner.set_ttl(ttl)
182     }
183 
184     /// Gets the value of the `IP_TTL` option for this socket.
185     ///
186     /// For more information about this option, see [`set_ttl`][link].
187     ///
188     /// # Notes
189     ///
190     /// On Windows make sure the stream is connected before calling this method,
191     /// by receiving an (writable) event. Trying to get `ttl` on an
192     /// unconnected `TcpStream` is unspecified behavior.
193     ///
194     /// [link]: #method.set_ttl
ttl(&self) -> io::Result<u32>195     pub fn ttl(&self) -> io::Result<u32> {
196         self.inner.ttl()
197     }
198 
199     /// Get the value of the `SO_ERROR` option on this socket.
200     ///
201     /// This will retrieve the stored error in the underlying socket, clearing
202     /// the field in the process. This can be useful for checking errors between
203     /// calls.
take_error(&self) -> io::Result<Option<io::Error>>204     pub fn take_error(&self) -> io::Result<Option<io::Error>> {
205         self.inner.take_error()
206     }
207 
208     /// Receives data on the socket from the remote address to which it is
209     /// connected, without removing that data from the queue. On success,
210     /// returns the number of bytes peeked.
211     ///
212     /// Successive calls return the same data. This is accomplished by passing
213     /// `MSG_PEEK` as a flag to the underlying recv system call.
peek(&self, buf: &mut [u8]) -> io::Result<usize>214     pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
215         self.inner.peek(buf)
216     }
217 
218     /// Execute an I/O operation ensuring that the socket receives more events
219     /// if it hits a [`WouldBlock`] error.
220     ///
221     /// # Notes
222     ///
223     /// This method is required to be called for **all** I/O operations to
224     /// ensure the user will receive events once the socket is ready again after
225     /// returning a [`WouldBlock`] error.
226     ///
227     /// [`WouldBlock`]: io::ErrorKind::WouldBlock
228     ///
229     /// # Examples
230     ///
231     #[cfg_attr(unix, doc = "```no_run")]
232     #[cfg_attr(windows, doc = "```ignore")]
233     /// # use std::error::Error;
234     /// #
235     /// # fn main() -> Result<(), Box<dyn Error>> {
236     /// use std::io;
237     /// #[cfg(any(unix, target_os = "wasi"))]
238     /// use std::os::fd::AsRawFd;
239     /// #[cfg(windows)]
240     /// use std::os::windows::io::AsRawSocket;
241     /// use mio::net::TcpStream;
242     ///
243     /// let address = "127.0.0.1:8080".parse().unwrap();
244     /// let stream = TcpStream::connect(address)?;
245     ///
246     /// // Wait until the stream is readable...
247     ///
248     /// // Read from the stream using a direct libc call, of course the
249     /// // `io::Read` implementation would be easier to use.
250     /// let mut buf = [0; 512];
251     /// let n = stream.try_io(|| {
252     ///     let buf_ptr = &mut buf as *mut _ as *mut _;
253     ///     #[cfg(unix)]
254     ///     let res = unsafe { libc::recv(stream.as_raw_fd(), buf_ptr, buf.len(), 0) };
255     ///     #[cfg(windows)]
256     ///     let res = unsafe { libc::recvfrom(stream.as_raw_socket() as usize, buf_ptr, buf.len() as i32, 0, std::ptr::null_mut(), std::ptr::null_mut()) };
257     ///     if res != -1 {
258     ///         Ok(res as usize)
259     ///     } else {
260     ///         // If EAGAIN or EWOULDBLOCK is set by libc::recv, the closure
261     ///         // should return `WouldBlock` error.
262     ///         Err(io::Error::last_os_error())
263     ///     }
264     /// })?;
265     /// eprintln!("read {} bytes", n);
266     /// # Ok(())
267     /// # }
268     /// ```
try_io<F, T>(&self, f: F) -> io::Result<T> where F: FnOnce() -> io::Result<T>,269     pub fn try_io<F, T>(&self, f: F) -> io::Result<T>
270     where
271         F: FnOnce() -> io::Result<T>,
272     {
273         self.inner.do_io(|_| f())
274     }
275 }
276 
277 impl Read for TcpStream {
read(&mut self, buf: &mut [u8]) -> io::Result<usize>278     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
279         self.inner.do_io(|mut inner| inner.read(buf))
280     }
281 
read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize>282     fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> {
283         self.inner.do_io(|mut inner| inner.read_vectored(bufs))
284     }
285 }
286 
287 impl<'a> Read for &'a TcpStream {
read(&mut self, buf: &mut [u8]) -> io::Result<usize>288     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
289         self.inner.do_io(|mut inner| inner.read(buf))
290     }
291 
read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize>292     fn read_vectored(&mut self, bufs: &mut [IoSliceMut<'_>]) -> io::Result<usize> {
293         self.inner.do_io(|mut inner| inner.read_vectored(bufs))
294     }
295 }
296 
297 impl Write for TcpStream {
write(&mut self, buf: &[u8]) -> io::Result<usize>298     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
299         self.inner.do_io(|mut inner| inner.write(buf))
300     }
301 
write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize>302     fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
303         self.inner.do_io(|mut inner| inner.write_vectored(bufs))
304     }
305 
flush(&mut self) -> io::Result<()>306     fn flush(&mut self) -> io::Result<()> {
307         self.inner.do_io(|mut inner| inner.flush())
308     }
309 }
310 
311 impl<'a> Write for &'a TcpStream {
write(&mut self, buf: &[u8]) -> io::Result<usize>312     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
313         self.inner.do_io(|mut inner| inner.write(buf))
314     }
315 
write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize>316     fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
317         self.inner.do_io(|mut inner| inner.write_vectored(bufs))
318     }
319 
flush(&mut self) -> io::Result<()>320     fn flush(&mut self) -> io::Result<()> {
321         self.inner.do_io(|mut inner| inner.flush())
322     }
323 }
324 
325 impl event::Source for TcpStream {
register( &mut self, registry: &Registry, token: Token, interests: Interest, ) -> io::Result<()>326     fn register(
327         &mut self,
328         registry: &Registry,
329         token: Token,
330         interests: Interest,
331     ) -> io::Result<()> {
332         self.inner.register(registry, token, interests)
333     }
334 
reregister( &mut self, registry: &Registry, token: Token, interests: Interest, ) -> io::Result<()>335     fn reregister(
336         &mut self,
337         registry: &Registry,
338         token: Token,
339         interests: Interest,
340     ) -> io::Result<()> {
341         self.inner.reregister(registry, token, interests)
342     }
343 
deregister(&mut self, registry: &Registry) -> io::Result<()>344     fn deregister(&mut self, registry: &Registry) -> io::Result<()> {
345         self.inner.deregister(registry)
346     }
347 }
348 
349 impl fmt::Debug for TcpStream {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result350     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
351         self.inner.fmt(f)
352     }
353 }
354 
355 #[cfg(any(unix, target_os = "hermit", target_os = "wasi"))]
356 impl IntoRawFd for TcpStream {
into_raw_fd(self) -> RawFd357     fn into_raw_fd(self) -> RawFd {
358         self.inner.into_inner().into_raw_fd()
359     }
360 }
361 
362 #[cfg(any(unix, target_os = "hermit", target_os = "wasi"))]
363 impl AsRawFd for TcpStream {
as_raw_fd(&self) -> RawFd364     fn as_raw_fd(&self) -> RawFd {
365         self.inner.as_raw_fd()
366     }
367 }
368 
369 #[cfg(any(unix, target_os = "hermit", target_os = "wasi"))]
370 impl FromRawFd for TcpStream {
371     /// Converts a `RawFd` to a `TcpStream`.
372     ///
373     /// # Notes
374     ///
375     /// The caller is responsible for ensuring that the socket is in
376     /// non-blocking mode.
from_raw_fd(fd: RawFd) -> TcpStream377     unsafe fn from_raw_fd(fd: RawFd) -> TcpStream {
378         TcpStream::from_std(FromRawFd::from_raw_fd(fd))
379     }
380 }
381 
382 #[cfg(any(unix, target_os = "hermit", target_os = "wasi"))]
383 impl From<TcpStream> for OwnedFd {
from(tcp_stream: TcpStream) -> Self384     fn from(tcp_stream: TcpStream) -> Self {
385         tcp_stream.inner.into_inner().into()
386     }
387 }
388 
389 #[cfg(any(unix, target_os = "hermit", target_os = "wasi"))]
390 impl AsFd for TcpStream {
as_fd(&self) -> BorrowedFd<'_>391     fn as_fd(&self) -> BorrowedFd<'_> {
392         self.inner.as_fd()
393     }
394 }
395 
396 #[cfg(any(unix, target_os = "hermit", target_os = "wasi"))]
397 impl From<OwnedFd> for TcpStream {
398     /// Converts a `RawFd` to a `TcpStream`.
399     ///
400     /// # Notes
401     ///
402     /// The caller is responsible for ensuring that the socket is in
403     /// non-blocking mode.
from(fd: OwnedFd) -> Self404     fn from(fd: OwnedFd) -> Self {
405         TcpStream::from_std(From::from(fd))
406     }
407 }
408 
#[cfg(windows)]
impl IntoRawSocket for TcpStream {
    /// Consumes the stream and returns the raw socket handle; the caller
    /// takes over ownership of it.
    fn into_raw_socket(self) -> RawSocket {
        let std_stream = self.inner.into_inner();
        std_stream.into_raw_socket()
    }
}
415 
#[cfg(windows)]
impl AsRawSocket for TcpStream {
    /// Returns the raw socket handle without transferring ownership.
    fn as_raw_socket(&self) -> RawSocket {
        let source = &self.inner;
        source.as_raw_socket()
    }
}
422 
#[cfg(windows)]
impl FromRawSocket for TcpStream {
    /// Converts a `RawSocket` to a `TcpStream`.
    ///
    /// # Safety
    ///
    /// `socket` must be an open socket handle that the caller owns; ownership
    /// is transferred to the returned stream.
    ///
    /// # Notes
    ///
    /// The caller is responsible for ensuring that the socket is in
    /// non-blocking mode.
    unsafe fn from_raw_socket(socket: RawSocket) -> TcpStream {
        let std_stream = net::TcpStream::from_raw_socket(socket);
        TcpStream::from_std(std_stream)
    }
}
435 
#[cfg(windows)]
impl From<TcpStream> for OwnedSocket {
    /// Unwraps the stream into an owned socket handle; dropping the result
    /// closes the socket.
    fn from(tcp_stream: TcpStream) -> Self {
        let std_stream: net::TcpStream = tcp_stream.inner.into_inner();
        OwnedSocket::from(std_stream)
    }
}
442 
#[cfg(windows)]
impl AsSocket for TcpStream {
    /// Borrows the underlying socket handle for the lifetime of `&self`.
    fn as_socket(&self) -> BorrowedSocket<'_> {
        let source = &self.inner;
        source.as_socket()
    }
}
449 
#[cfg(windows)]
impl From<OwnedSocket> for TcpStream {
    /// Converts an `OwnedSocket` to a `TcpStream`, taking ownership of the
    /// socket handle.
    ///
    /// # Notes
    ///
    /// The caller is responsible for ensuring that the socket is in
    /// non-blocking mode. The handle is assumed to refer to a TCP stream
    /// socket; this is not checked here.
    fn from(socket: OwnedSocket) -> Self {
        TcpStream::from_std(net::TcpStream::from(socket))
    }
}
462 
463 impl From<TcpStream> for net::TcpStream {
from(stream: TcpStream) -> Self464     fn from(stream: TcpStream) -> Self {
465         // Safety: This is safe since we are extracting the raw fd from a well-constructed
466         // mio::net::TcpStream which ensures that we actually pass in a valid file
467         // descriptor/socket
468         unsafe {
469             #[cfg(any(unix, target_os = "hermit", target_os = "wasi"))]
470             {
471                 net::TcpStream::from_raw_fd(stream.into_raw_fd())
472             }
473             #[cfg(windows)]
474             {
475                 net::TcpStream::from_raw_socket(stream.into_raw_socket())
476             }
477         }
478     }
479 }
480