xref: /aosp_15_r20/external/crosvm/third_party/vmm_vhost/src/connection.rs (revision bb4ee6a4ae7042d18b07a98463b9c8b875e44b39)
1 // Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
2 // SPDX-License-Identifier: Apache-2.0
3 
4 //! Common data structures for listener and connection.
5 
6 use std::fs::File;
7 use std::io::IoSliceMut;
8 use std::mem;
9 
10 use base::AsRawDescriptor;
11 use base::RawDescriptor;
12 use zerocopy::AsBytes;
13 use zerocopy::FromBytes;
14 
15 use crate::connection::Req;
16 use crate::message::FrontendReq;
17 use crate::message::*;
18 use crate::sys::PlatformConnection;
19 use crate::Error;
20 use crate::Result;
21 
22 /// Listener for accepting connections.
23 pub trait Listener: Sized {
24     /// Accept an incoming connection.
accept(&mut self) -> Result<Option<Connection<FrontendReq>>>25     fn accept(&mut self) -> Result<Option<Connection<FrontendReq>>>;
26 
27     /// Change blocking status on the listener.
set_nonblocking(&self, block: bool) -> Result<()>28     fn set_nonblocking(&self, block: bool) -> Result<()>;
29 }
30 
/// Moves the read cursor of `bufs` forward by `count` bytes.
///
/// Every leading buffer that `count` covers completely is dropped from the front of `bufs`,
/// including zero-length buffers. The first surviving buffer is then shortened by whatever is
/// left of `count`.
///
/// This mirrors the unstable `IoSliceMut::advance_slices`, but operates on plain `&mut [u8]`
/// slices. Unlike that API, advancing past the end does not panic: `bufs` simply ends up empty.
fn advance_slices_mut(bufs: &mut &mut [&mut [u8]], count: usize) {
    // Bytes of `count` not yet attributed to a fully consumed buffer.
    let mut left = count;
    let consumed = bufs
        .iter()
        .take_while(|buf| {
            let fits = buf.len() <= left;
            if fits {
                left -= buf.len();
            }
            fits
        })
        .count();

    // Detach the outer slice so it can be re-sliced with its full lifetime.
    *bufs = &mut std::mem::take(bufs)[consumed..];

    // Trim the partially consumed head buffer, if any remains.
    if let Some(front) = bufs.first_mut() {
        let partial = std::mem::take(front);
        *front = &mut partial[left..];
    }
}
51 
/// A vhost-user connection at a low abstraction level. Provides methods for sending and receiving
/// vhost-user message headers and bodies.
///
/// Builds on top of `PlatformConnection`, which provides methods for sending and receiving raw
/// bytes and file descriptors (a thin cross-platform abstraction for unix domain sockets).
///
/// `R` is the request type carried in the headers this connection sends and receives
/// (`VhostUserMsgHeader<R>`).
pub struct Connection<R: Req>(
    /// Underlying transport; every send and receive on `Connection` is delegated to it.
    pub(crate) PlatformConnection,
    /// Binds the connection to its request type `R` without storing an `R`.
    pub(crate) std::marker::PhantomData<R>,
    // Mark `Connection` as `!Sync` because message sends and recvs cannot safely be done
    // concurrently.
    pub(crate) std::marker::PhantomData<std::cell::Cell<()>>,
);
64 
65 impl<R: Req> Connection<R> {
66     /// Sends a header-only message with optional attached file descriptors.
send_header_only_message( &self, hdr: &VhostUserMsgHeader<R>, fds: Option<&[RawDescriptor]>, ) -> Result<()>67     pub fn send_header_only_message(
68         &self,
69         hdr: &VhostUserMsgHeader<R>,
70         fds: Option<&[RawDescriptor]>,
71     ) -> Result<()> {
72         self.0.send_message(hdr.as_bytes(), &[], &[], fds)
73     }
74 
75     /// Send a message with header and body. Optional file descriptors may be attached to
76     /// the message.
send_message<T: AsBytes>( &self, hdr: &VhostUserMsgHeader<R>, body: &T, fds: Option<&[RawDescriptor]>, ) -> Result<()>77     pub fn send_message<T: AsBytes>(
78         &self,
79         hdr: &VhostUserMsgHeader<R>,
80         body: &T,
81         fds: Option<&[RawDescriptor]>,
82     ) -> Result<()> {
83         self.0
84             .send_message(hdr.as_bytes(), body.as_bytes(), &[], fds)
85     }
86 
87     /// Send a message with header and body. `payload` is appended to the end of the body. Optional
88     /// file descriptors may also be attached to the message.
send_message_with_payload<T: Sized + AsBytes>( &self, hdr: &VhostUserMsgHeader<R>, body: &T, payload: &[u8], fds: Option<&[RawDescriptor]>, ) -> Result<()>89     pub fn send_message_with_payload<T: Sized + AsBytes>(
90         &self,
91         hdr: &VhostUserMsgHeader<R>,
92         body: &T,
93         payload: &[u8],
94         fds: Option<&[RawDescriptor]>,
95     ) -> Result<()> {
96         self.0
97             .send_message(hdr.as_bytes(), body.as_bytes(), payload, fds)
98     }
99 
100     /// Reads all bytes into the given scatter/gather vectors with optional attached files. Will
101     /// loop until all data has been transfered and errors if EOF is reached before then.
102     ///
103     /// # Return:
104     /// * - received fds on success
105     /// * - `Disconnect` - client is closed
106     ///
107     /// # TODO
108     /// This function takes a slice of `&mut [u8]` instead of `IoSliceMut` because the internal
109     /// cursor needs to be moved by `advance_slices_mut()`.
110     /// Once `IoSliceMut::advance_slices()` becomes stable, this should be updated.
111     /// <https://github.com/rust-lang/rust/issues/62726>.
recv_into_bufs_all(&self, mut bufs: &mut [&mut [u8]]) -> Result<Vec<File>>112     fn recv_into_bufs_all(&self, mut bufs: &mut [&mut [u8]]) -> Result<Vec<File>> {
113         let mut first_read = true;
114         let mut rfds = Vec::new();
115 
116         // Guarantee that `bufs` becomes empty if it doesn't contain any data.
117         advance_slices_mut(&mut bufs, 0);
118 
119         while !bufs.is_empty() {
120             let mut slices: Vec<IoSliceMut> = bufs.iter_mut().map(|b| IoSliceMut::new(b)).collect();
121             let res = self.0.recv_into_bufs(&mut slices, true);
122             match res {
123                 Ok((0, _)) => return Err(Error::PartialMessage),
124                 Ok((n, fds)) => {
125                     if first_read {
126                         first_read = false;
127                         if let Some(fds) = fds {
128                             rfds = fds;
129                         }
130                     }
131                     advance_slices_mut(&mut bufs, n);
132                 }
133                 Err(e) => match e {
134                     Error::SocketRetry(_) => {}
135                     _ => return Err(e),
136                 },
137             }
138         }
139         Ok(rfds)
140     }
141 
142     /// Receive message header
143     ///
144     /// Errors if the header is invalid.
145     ///
146     /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be accepted and all
147     /// other file descriptor will be discard silently.
recv_header(&self) -> Result<(VhostUserMsgHeader<R>, Vec<File>)>148     pub fn recv_header(&self) -> Result<(VhostUserMsgHeader<R>, Vec<File>)> {
149         let mut hdr = VhostUserMsgHeader::default();
150         let files = self.recv_into_bufs_all(&mut [hdr.as_bytes_mut()])?;
151         if !hdr.is_valid() {
152             return Err(Error::InvalidMessage);
153         }
154         Ok((hdr, files))
155     }
156 
157     /// Receive the body following the header `hdr`.
recv_body_bytes(&self, hdr: &VhostUserMsgHeader<R>) -> Result<Vec<u8>>158     pub fn recv_body_bytes(&self, hdr: &VhostUserMsgHeader<R>) -> Result<Vec<u8>> {
159         // NOTE: `recv_into_bufs_all` is a noop when the buffer is empty, so `hdr.get_size() == 0`
160         // works as expected.
161         let mut body = vec![0; hdr.get_size().try_into().unwrap()];
162         let files = self.recv_into_bufs_all(&mut [&mut body[..]])?;
163         if !files.is_empty() {
164             return Err(Error::InvalidMessage);
165         }
166         Ok(body)
167     }
168 
169     /// Receive a message header and body.
170     ///
171     /// Errors if the header or body is invalid.
172     ///
173     /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be
174     /// accepted and all other file descriptor will be discard silently.
recv_message<T: AsBytes + FromBytes + VhostUserMsgValidator>( &self, ) -> Result<(VhostUserMsgHeader<R>, T, Vec<File>)>175     pub fn recv_message<T: AsBytes + FromBytes + VhostUserMsgValidator>(
176         &self,
177     ) -> Result<(VhostUserMsgHeader<R>, T, Vec<File>)> {
178         let mut hdr = VhostUserMsgHeader::default();
179         let mut body = T::new_zeroed();
180         let mut slices = [hdr.as_bytes_mut(), body.as_bytes_mut()];
181         let files = self.recv_into_bufs_all(&mut slices)?;
182 
183         if !hdr.is_valid() || !body.is_valid() {
184             return Err(Error::InvalidMessage);
185         }
186 
187         Ok((hdr, body, files))
188     }
189 
190     /// Receive a message header and body, where the body includes a variable length payload at the
191     /// end.
192     ///
193     /// Errors if the header or body is invalid.
194     ///
195     /// Note, only the first MAX_ATTACHED_FD_ENTRIES file descriptors will be accepted and all
196     /// other file descriptor will be discard silently.
recv_message_with_payload<T: AsBytes + FromBytes + VhostUserMsgValidator>( &self, ) -> Result<(VhostUserMsgHeader<R>, T, Vec<u8>, Vec<File>)>197     pub fn recv_message_with_payload<T: AsBytes + FromBytes + VhostUserMsgValidator>(
198         &self,
199     ) -> Result<(VhostUserMsgHeader<R>, T, Vec<u8>, Vec<File>)> {
200         let (hdr, files) = self.recv_header()?;
201 
202         let mut body = T::new_zeroed();
203         let payload_size = hdr.get_size() as usize - mem::size_of::<T>();
204         let mut buf: Vec<u8> = vec![0; payload_size];
205         let mut slices = [body.as_bytes_mut(), buf.as_bytes_mut()];
206         let more_files = self.recv_into_bufs_all(&mut slices)?;
207         if !body.is_valid() || !more_files.is_empty() {
208             return Err(Error::InvalidMessage);
209         }
210 
211         Ok((hdr, body, buf, files))
212     }
213 }
214 
215 impl<R: Req> AsRawDescriptor for Connection<R> {
as_raw_descriptor(&self) -> RawDescriptor216     fn as_raw_descriptor(&self) -> RawDescriptor {
217         self.0.as_raw_descriptor()
218     }
219 }
220 
#[cfg(test)]
pub(crate) mod tests {
    use std::io::Read;
    use std::io::Seek;
    use std::io::SeekFrom;
    use std::io::Write;

    use tempfile::tempfile;

    use super::*;
    use crate::message::VhostUserEmptyMessage;
    use crate::message::VhostUserU64;

    #[test]
    fn send_header_only() {
        let (client_connection, server_connection) = Connection::pair().unwrap();
        let hdr1 = VhostUserMsgHeader::new(FrontendReq::GET_FEATURES, 0, 0);
        client_connection
            .send_header_only_message(&hdr1, None)
            .unwrap();
        let (hdr2, _, files) = server_connection
            .recv_message::<VhostUserEmptyMessage>()
            .unwrap();
        assert_eq!(hdr1, hdr2);
        assert!(files.is_empty());
    }

    #[test]
    fn send_data() {
        let (client_connection, server_connection) = Connection::pair().unwrap();
        let hdr1 = VhostUserMsgHeader::new(FrontendReq::SET_FEATURES, 0, 8);
        client_connection
            .send_message(&hdr1, &VhostUserU64::new(0xf00dbeefdeadf00d), None)
            .unwrap();
        let (hdr2, body, files) = server_connection.recv_message::<VhostUserU64>().unwrap();
        assert_eq!(hdr1, hdr2);
        let value = body.value;
        assert_eq!(value, 0xf00dbeefdeadf00d);
        assert!(files.is_empty());
    }

    #[test]
    fn send_data_with_payload() {
        let (client_connection, server_connection) = Connection::pair().unwrap();
        let payload = [0xa5u8, 0x5a, 0x00, 0xff];
        // Header size = 8-byte `VhostUserU64` body + 4-byte payload.
        let hdr1 = VhostUserMsgHeader::new(FrontendReq::SET_FEATURES, 0, 12);
        client_connection
            .send_message_with_payload(
                &hdr1,
                &VhostUserU64::new(0x1122334455667788),
                &payload,
                None,
            )
            .unwrap();
        let (hdr2, body, received, files) = server_connection
            .recv_message_with_payload::<VhostUserU64>()
            .unwrap();
        assert_eq!(hdr1, hdr2);
        let value = body.value;
        assert_eq!(value, 0x1122334455667788);
        assert_eq!(received, payload);
        assert!(files.is_empty());
    }

    #[test]
    fn send_fd() {
        let (client_connection, server_connection) = Connection::pair().unwrap();

        let mut fd = tempfile().unwrap();
        write!(fd, "test").unwrap();

        // Normal case for sending/receiving file descriptors
        let hdr1 = VhostUserMsgHeader::new(FrontendReq::SET_MEM_TABLE, 0, 0);
        client_connection
            .send_header_only_message(&hdr1, Some(&[fd.as_raw_descriptor()]))
            .unwrap();

        let (hdr2, _, files) = server_connection
            .recv_message::<VhostUserEmptyMessage>()
            .unwrap();
        assert_eq!(hdr1, hdr2);
        assert_eq!(files.len(), 1);
        let mut file = &files[0];
        let mut content = String::new();
        file.seek(SeekFrom::Start(0)).unwrap();
        file.read_to_string(&mut content).unwrap();
        assert_eq!(content, "test");
    }

    #[test]
    fn test_advance_slices_mut() {
        // Test case from https://doc.rust-lang.org/std/io/struct.IoSliceMut.html#method.advance_slices
        let mut buf1 = [1; 8];
        let mut buf2 = [2; 16];
        let mut buf3 = [3; 8];
        let mut bufs = &mut [&mut buf1[..], &mut buf2[..], &mut buf3[..]][..];
        advance_slices_mut(&mut bufs, 10);
        assert_eq!(bufs[0], [2; 14].as_ref());
        assert_eq!(bufs[1], [3; 8].as_ref());
    }

    #[test]
    fn test_advance_slices_mut_exact_boundary() {
        let mut a = [1u8; 4];
        let mut b = [2u8; 4];
        let mut bufs = &mut [&mut a[..], &mut b[..]][..];
        advance_slices_mut(&mut bufs, 4);
        assert_eq!(bufs.len(), 1);
        assert_eq!(bufs[0], [2u8; 4].as_ref());
    }

    #[test]
    fn test_advance_slices_mut_zero_skips_empty_buffers() {
        // `recv_into_bufs_all` relies on this to turn all-empty inputs into an empty slice.
        let mut empty: [u8; 0] = [];
        let mut data = [7u8; 3];
        let mut bufs = &mut [&mut empty[..], &mut data[..]][..];
        advance_slices_mut(&mut bufs, 0);
        assert_eq!(bufs.len(), 1);
        assert_eq!(bufs[0], [7u8; 3].as_ref());
    }

    #[test]
    fn test_advance_slices_mut_consumes_everything() {
        let mut a = [1u8; 2];
        let mut b = [2u8; 3];
        let mut bufs = &mut [&mut a[..], &mut b[..]][..];
        advance_slices_mut(&mut bufs, 5);
        assert!(bufs.is_empty());
    }
}
298 }
299