1 // Copyright 2024 The Android Open Source Project
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 // Copied from ChromiumOS with relicensing:
16 // src/platform2/vm_tools/chunnel/src/stream.rs
17 
//! This module provides an abstraction over the various connection-oriented stream socket types.
19 
20 use std::fmt;
21 use std::io;
22 use std::net::TcpStream;
23 use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
24 use std::os::unix::net::UnixStream;
25 use std::result;
26 
27 use libc::{self, c_void, shutdown, EPIPE, SHUT_WR};
28 use vsock::VsockAddr;
29 use vsock::VsockStream;
30 
31 /// Parse a vsock SocketAddr from a string. vsock socket addresses are of the form
32 /// "vsock:cid:port".
parse_vsock_addr(addr: &str) -> result::Result<VsockAddr, io::Error>33 pub fn parse_vsock_addr(addr: &str) -> result::Result<VsockAddr, io::Error> {
34     let components: Vec<&str> = addr.split(':').collect();
35     if components.len() != 3 || components[0] != "vsock" {
36         return Err(io::Error::from_raw_os_error(libc::EINVAL));
37     }
38 
39     Ok(VsockAddr::new(
40         components[1].parse().map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?,
41         components[2].parse().map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?,
42     ))
43 }
44 
/// A generic wrapper around any connection-oriented stream socket, identified by its raw
/// file descriptor.
///
/// The descriptor is owned: it is closed when the `StreamSocket` is dropped. Writes can be
/// shut down early, independently of dropping, with [`StreamSocket::shut_down_write`].
#[derive(Debug)]
pub struct StreamSocket {
    /// The owned socket descriptor.
    fd: RawFd,
    /// Set once writes are known to be shut down, either explicitly or after a write hit EPIPE.
    shut_down: bool,
}
52 
53 impl StreamSocket {
54     /// Connects to the given socket address. Supported socket types are vsock, unix, and TCP.
connect(sockaddr: &str) -> result::Result<StreamSocket, StreamSocketError>55     pub fn connect(sockaddr: &str) -> result::Result<StreamSocket, StreamSocketError> {
56         const UNIX_PREFIX: &str = "unix:";
57         const VSOCK_PREFIX: &str = "vsock:";
58 
59         if sockaddr.starts_with(VSOCK_PREFIX) {
60             let addr = parse_vsock_addr(sockaddr)
61                 .map_err(|e| StreamSocketError::ConnectVsock(sockaddr.to_string(), e))?;
62             let vsock_stream = VsockStream::connect(&addr)
63                 .map_err(|e| StreamSocketError::ConnectVsock(sockaddr.to_string(), e))?;
64             Ok(vsock_stream.into())
65         } else if sockaddr.starts_with(UNIX_PREFIX) {
66             let (_prefix, sock_path) = sockaddr.split_at(UNIX_PREFIX.len());
67             let unix_stream = UnixStream::connect(sock_path)
68                 .map_err(|e| StreamSocketError::ConnectUnix(sockaddr.to_string(), e))?;
69             Ok(unix_stream.into())
70         } else {
71             // Assume this is a TCP stream.
72             let tcp_stream = TcpStream::connect(sockaddr)
73                 .map_err(|e| StreamSocketError::ConnectTcp(sockaddr.to_string(), e))?;
74             Ok(tcp_stream.into())
75         }
76     }
77 
78     /// Shuts down writes to the socket using shutdown(2).
shut_down_write(&mut self) -> io::Result<()>79     pub fn shut_down_write(&mut self) -> io::Result<()> {
80         // SAFETY:
81         // Safe because no memory is modified and the return value is checked.
82         let ret = unsafe { shutdown(self.fd, SHUT_WR) };
83         if ret < 0 {
84             return Err(io::Error::last_os_error());
85         }
86 
87         self.shut_down = true;
88         Ok(())
89     }
90 
91     /// Returns true if the socket has been shut down for writes, false otherwise.
is_shut_down(&self) -> bool92     pub fn is_shut_down(&self) -> bool {
93         self.shut_down
94     }
95 }
96 
97 impl io::Read for StreamSocket {
read(&mut self, buf: &mut [u8]) -> io::Result<usize>98     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
99         // SAFETY:
100         // Safe because this will only modify the contents of |buf| and we check the return value.
101         let ret = unsafe { libc::read(self.fd, buf.as_mut_ptr() as *mut c_void, buf.len()) };
102         if ret < 0 {
103             return Err(io::Error::last_os_error());
104         }
105 
106         Ok(ret as usize)
107     }
108 }
109 
110 impl io::Write for StreamSocket {
write(&mut self, buf: &[u8]) -> io::Result<usize>111     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
112         // SAFETY:
113         // Safe because this doesn't modify any memory and we check the return value.
114         let ret = unsafe { libc::write(self.fd, buf.as_ptr() as *const c_void, buf.len()) };
115         if ret < 0 {
116             // If a write causes EPIPE then the socket is shut down for writes.
117             let err = io::Error::last_os_error();
118             if let Some(errno) = err.raw_os_error() {
119                 if errno == EPIPE {
120                     self.shut_down = true
121                 }
122             }
123 
124             return Err(err);
125         }
126 
127         Ok(ret as usize)
128     }
129 
flush(&mut self) -> io::Result<()>130     fn flush(&mut self) -> io::Result<()> {
131         // No buffered data so nothing to do.
132         Ok(())
133     }
134 }
135 
136 impl AsRawFd for StreamSocket {
as_raw_fd(&self) -> RawFd137     fn as_raw_fd(&self) -> RawFd {
138         self.fd
139     }
140 }
141 
142 impl From<TcpStream> for StreamSocket {
from(stream: TcpStream) -> Self143     fn from(stream: TcpStream) -> Self {
144         StreamSocket { fd: stream.into_raw_fd(), shut_down: false }
145     }
146 }
147 
148 impl From<UnixStream> for StreamSocket {
from(stream: UnixStream) -> Self149     fn from(stream: UnixStream) -> Self {
150         StreamSocket { fd: stream.into_raw_fd(), shut_down: false }
151     }
152 }
153 
154 impl From<VsockStream> for StreamSocket {
from(stream: VsockStream) -> Self155     fn from(stream: VsockStream) -> Self {
156         StreamSocket { fd: stream.into_raw_fd(), shut_down: false }
157     }
158 }
159 
160 impl FromRawFd for StreamSocket {
from_raw_fd(fd: RawFd) -> Self161     unsafe fn from_raw_fd(fd: RawFd) -> Self {
162         StreamSocket { fd, shut_down: false }
163     }
164 }
165 
166 impl Drop for StreamSocket {
drop(&mut self)167     fn drop(&mut self) {
168         // SAFETY:
169         // Safe because this doesn't modify any memory and we are the only
170         // owner of the file descriptor.
171         unsafe { libc::close(self.fd) };
172     }
173 }
174 
175 /// Error enums for StreamSocket.
176 #[remain::sorted]
177 #[derive(Debug)]
178 pub enum StreamSocketError {
179     /// Error on connecting TCP socket.
180     ConnectTcp(String, io::Error),
181     /// Error on connecting unix socket.
182     ConnectUnix(String, io::Error),
183     /// Error on connecting vsock socket.
184     ConnectVsock(String, io::Error),
185 }
186 
187 impl fmt::Display for StreamSocketError {
188     #[remain::check]
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result189     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
190         use self::StreamSocketError::*;
191 
192         #[remain::sorted]
193         match self {
194             ConnectTcp(sockaddr, e) => {
195                 write!(f, "failed to connect to TCP sockaddr {}: {}", sockaddr, e)
196             }
197             ConnectUnix(sockaddr, e) => {
198                 write!(f, "failed to connect to unix sockaddr {}: {}", sockaddr, e)
199             }
200             ConnectVsock(sockaddr, e) => {
201                 write!(f, "failed to connect to vsock sockaddr {}: {}", sockaddr, e)
202             }
203         }
204     }
205 }
206 
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::{Read, Write};
    use std::net::TcpListener;
    use std::os::unix::net::{UnixListener, UnixStream};
    use tempfile::TempDir;

    #[test]
    fn parse_vsock_addr_valid() {
        let addr = parse_vsock_addr("vsock:3:5000").unwrap();
        assert_eq!(addr.cid(), 3);
        assert_eq!(addr.port(), 5000);
    }

    #[test]
    fn parse_vsock_addr_invalid() {
        let bad_inputs = &[
            "vsock:3",
            "vsock:3:5000:1",
            "unix:3:5000",
            "vsock:x:5000",
            "vsock:3:",
            "vsock:-1:5000",
        ];
        for bad in bad_inputs {
            match parse_vsock_addr(bad) {
                Ok(_) => panic!("{:?} should have been rejected", bad),
                Err(e) => assert_eq!(e.raw_os_error(), Some(libc::EINVAL), "input {:?}", bad),
            }
        }
    }

    #[test]
    fn sock_connect_tcp() {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let sockaddr = format!("127.0.0.1:{}", listener.local_addr().unwrap().port());

        let _stream = StreamSocket::connect(&sockaddr).unwrap();
    }

    #[test]
    fn sock_connect_unix() {
        let tempdir = TempDir::new().unwrap();
        let path = tempdir.path().join("test.sock");
        let _listener = UnixListener::bind(&path).unwrap();

        let unix_addr = format!("unix:{}", path.to_str().unwrap());
        let _stream = StreamSocket::connect(&unix_addr).unwrap();
    }

    #[test]
    fn invalid_sockaddr() {
        assert!(StreamSocket::connect("this is not a valid sockaddr").is_err());
    }

    #[test]
    fn shut_down_write() {
        let (unix_stream, _dummy) = UnixStream::pair().unwrap();
        let mut stream: StreamSocket = unix_stream.into();

        stream.write_all(b"hello").unwrap();

        stream.shut_down_write().unwrap();

        assert!(stream.is_shut_down());
        // Writing after SHUT_WR must fail with EPIPE and keep the socket marked as shut down.
        let err = stream.write(b"goodbye").unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::BrokenPipe);
        assert!(stream.is_shut_down());
    }

    #[test]
    fn read_from_shut_down_sock() {
        let (unix_stream1, unix_stream2) = UnixStream::pair().unwrap();
        let mut stream1: StreamSocket = unix_stream1.into();
        let mut stream2: StreamSocket = unix_stream2.into();

        stream1.shut_down_write().unwrap();

        // Reads from the other end of the socket should now return EOF.
        let mut buf = Vec::new();
        assert_eq!(stream2.read_to_end(&mut buf).unwrap(), 0);
    }

    #[test]
    fn write_then_read_round_trip() {
        let (a, b) = UnixStream::pair().unwrap();
        let mut writer: StreamSocket = a.into();
        let mut reader: StreamSocket = b.into();

        writer.write_all(b"ping").unwrap();
        writer.shut_down_write().unwrap();

        let mut buf = Vec::new();
        reader.read_to_end(&mut buf).unwrap();
        assert_eq!(buf, b"ping".to_vec());
    }
}
264