1 // Copyright (C) 2019-2021 Alibaba Cloud. All rights reserved.
2 // SPDX-License-Identifier: Apache-2.0
3 
4 use std::fs::File;
5 use std::mem;
6 use std::os::unix::io::{AsRawFd, RawFd};
7 use std::os::unix::net::UnixStream;
8 use std::sync::{Arc, Mutex};
9 
10 use super::connection::Endpoint;
11 use super::message::*;
12 use super::{Error, HandlerResult, Result};
13 
14 /// Define services provided by masters for the slave communication channel.
15 ///
16 /// The vhost-user specification defines a slave communication channel, by which slaves could
17 /// request services from masters. The [VhostUserMasterReqHandler] trait defines services provided
18 /// by masters, and it's used both on the master side and slave side.
19 /// - on the slave side, a stub forwarder implementing [VhostUserMasterReqHandler] will proxy
20 ///   service requests to masters. The [Slave] is an example stub forwarder.
21 /// - on the master side, the [MasterReqHandler] will forward service requests to a handler
22 ///   implementing [VhostUserMasterReqHandler].
23 ///
24 /// The [VhostUserMasterReqHandler] trait is design with interior mutability to improve performance
25 /// for multi-threading.
26 ///
27 /// [VhostUserMasterReqHandler]: trait.VhostUserMasterReqHandler.html
28 /// [MasterReqHandler]: struct.MasterReqHandler.html
29 /// [Slave]: struct.Slave.html
30 pub trait VhostUserMasterReqHandler {
31     /// Handle device configuration change notifications.
handle_config_change(&self) -> HandlerResult<u64>32     fn handle_config_change(&self) -> HandlerResult<u64> {
33         Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
34     }
35 
36     /// Handle virtio-fs map file requests.
fs_slave_map(&self, _fs: &VhostUserFSSlaveMsg, _fd: &dyn AsRawFd) -> HandlerResult<u64>37     fn fs_slave_map(&self, _fs: &VhostUserFSSlaveMsg, _fd: &dyn AsRawFd) -> HandlerResult<u64> {
38         Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
39     }
40 
41     /// Handle virtio-fs unmap file requests.
fs_slave_unmap(&self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64>42     fn fs_slave_unmap(&self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
43         Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
44     }
45 
46     /// Handle virtio-fs sync file requests.
fs_slave_sync(&self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64>47     fn fs_slave_sync(&self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
48         Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
49     }
50 
51     /// Handle virtio-fs file IO requests.
fs_slave_io(&self, _fs: &VhostUserFSSlaveMsg, _fd: &dyn AsRawFd) -> HandlerResult<u64>52     fn fs_slave_io(&self, _fs: &VhostUserFSSlaveMsg, _fd: &dyn AsRawFd) -> HandlerResult<u64> {
53         Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
54     }
55 
56     // fn handle_iotlb_msg(&mut self, iotlb: VhostUserIotlb);
57     // fn handle_vring_host_notifier(&mut self, area: VhostUserVringArea, fd: &dyn AsRawFd);
58 }
59 
60 /// A helper trait mirroring [VhostUserMasterReqHandler] but without interior mutability.
61 ///
62 /// [VhostUserMasterReqHandler]: trait.VhostUserMasterReqHandler.html
63 pub trait VhostUserMasterReqHandlerMut {
64     /// Handle device configuration change notifications.
handle_config_change(&mut self) -> HandlerResult<u64>65     fn handle_config_change(&mut self) -> HandlerResult<u64> {
66         Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
67     }
68 
69     /// Handle virtio-fs map file requests.
fs_slave_map(&mut self, _fs: &VhostUserFSSlaveMsg, _fd: &dyn AsRawFd) -> HandlerResult<u64>70     fn fs_slave_map(&mut self, _fs: &VhostUserFSSlaveMsg, _fd: &dyn AsRawFd) -> HandlerResult<u64> {
71         Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
72     }
73 
74     /// Handle virtio-fs unmap file requests.
fs_slave_unmap(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64>75     fn fs_slave_unmap(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
76         Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
77     }
78 
79     /// Handle virtio-fs sync file requests.
fs_slave_sync(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64>80     fn fs_slave_sync(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
81         Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
82     }
83 
84     /// Handle virtio-fs file IO requests.
fs_slave_io(&mut self, _fs: &VhostUserFSSlaveMsg, _fd: &dyn AsRawFd) -> HandlerResult<u64>85     fn fs_slave_io(&mut self, _fs: &VhostUserFSSlaveMsg, _fd: &dyn AsRawFd) -> HandlerResult<u64> {
86         Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
87     }
88 
89     // fn handle_iotlb_msg(&mut self, iotlb: VhostUserIotlb);
90     // fn handle_vring_host_notifier(&mut self, area: VhostUserVringArea, fd: RawFd);
91 }
92 
93 impl<S: VhostUserMasterReqHandlerMut> VhostUserMasterReqHandler for Mutex<S> {
handle_config_change(&self) -> HandlerResult<u64>94     fn handle_config_change(&self) -> HandlerResult<u64> {
95         self.lock().unwrap().handle_config_change()
96     }
97 
fs_slave_map(&self, fs: &VhostUserFSSlaveMsg, fd: &dyn AsRawFd) -> HandlerResult<u64>98     fn fs_slave_map(&self, fs: &VhostUserFSSlaveMsg, fd: &dyn AsRawFd) -> HandlerResult<u64> {
99         self.lock().unwrap().fs_slave_map(fs, fd)
100     }
101 
fs_slave_unmap(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64>102     fn fs_slave_unmap(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
103         self.lock().unwrap().fs_slave_unmap(fs)
104     }
105 
fs_slave_sync(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64>106     fn fs_slave_sync(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
107         self.lock().unwrap().fs_slave_sync(fs)
108     }
109 
fs_slave_io(&self, fs: &VhostUserFSSlaveMsg, fd: &dyn AsRawFd) -> HandlerResult<u64>110     fn fs_slave_io(&self, fs: &VhostUserFSSlaveMsg, fd: &dyn AsRawFd) -> HandlerResult<u64> {
111         self.lock().unwrap().fs_slave_io(fs, fd)
112     }
113 }
114 
115 /// Server to handle service requests from slaves from the slave communication channel.
116 ///
117 /// The [MasterReqHandler] acts as a server on the master side, to handle service requests from
118 /// slaves on the slave communication channel. It's actually a proxy invoking the registered
119 /// handler implementing [VhostUserMasterReqHandler] to do the real work.
120 ///
121 /// [MasterReqHandler]: struct.MasterReqHandler.html
122 /// [VhostUserMasterReqHandler]: trait.VhostUserMasterReqHandler.html
123 pub struct MasterReqHandler<S: VhostUserMasterReqHandler> {
124     // underlying Unix domain socket for communication
125     sub_sock: Endpoint<SlaveReq>,
126     tx_sock: UnixStream,
127     // Protocol feature VHOST_USER_PROTOCOL_F_REPLY_ACK has been negotiated.
128     reply_ack_negotiated: bool,
129     // the VirtIO backend device object
130     backend: Arc<S>,
131     // whether the endpoint has encountered any failure
132     error: Option<i32>,
133 }
134 
135 impl<S: VhostUserMasterReqHandler> MasterReqHandler<S> {
136     /// Create a server to handle service requests from slaves on the slave communication channel.
137     ///
138     /// This opens a pair of connected anonymous sockets to form the slave communication channel.
139     /// The socket fd returned by [Self::get_tx_raw_fd()] should be sent to the slave by
140     /// [VhostUserMaster::set_slave_request_fd()].
141     ///
142     /// [Self::get_tx_raw_fd()]: struct.MasterReqHandler.html#method.get_tx_raw_fd
143     /// [VhostUserMaster::set_slave_request_fd()]: trait.VhostUserMaster.html#tymethod.set_slave_request_fd
new(backend: Arc<S>) -> Result<Self>144     pub fn new(backend: Arc<S>) -> Result<Self> {
145         let (tx, rx) = UnixStream::pair().map_err(Error::SocketError)?;
146 
147         Ok(MasterReqHandler {
148             sub_sock: Endpoint::<SlaveReq>::from_stream(rx),
149             tx_sock: tx,
150             reply_ack_negotiated: false,
151             backend,
152             error: None,
153         })
154     }
155 
156     /// Get the socket fd for the slave to communication with the master.
157     ///
158     /// The returned fd should be sent to the slave by [VhostUserMaster::set_slave_request_fd()].
159     ///
160     /// [VhostUserMaster::set_slave_request_fd()]: trait.VhostUserMaster.html#tymethod.set_slave_request_fd
get_tx_raw_fd(&self) -> RawFd161     pub fn get_tx_raw_fd(&self) -> RawFd {
162         self.tx_sock.as_raw_fd()
163     }
164 
165     /// Set the negotiation state of the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature.
166     ///
167     /// When the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature has been negotiated,
168     /// the "REPLY_ACK" flag will be set in the message header for every slave to master request
169     /// message.
set_reply_ack_flag(&mut self, enable: bool)170     pub fn set_reply_ack_flag(&mut self, enable: bool) {
171         self.reply_ack_negotiated = enable;
172     }
173 
174     /// Mark endpoint as failed or in normal state.
set_failed(&mut self, error: i32)175     pub fn set_failed(&mut self, error: i32) {
176         if error == 0 {
177             self.error = None;
178         } else {
179             self.error = Some(error);
180         }
181     }
182 
183     /// Main entrance to server slave request from the slave communication channel.
184     ///
185     /// The caller needs to:
186     /// - serialize calls to this function
187     /// - decide what to do when errer happens
188     /// - optional recover from failure
handle_request(&mut self) -> Result<u64>189     pub fn handle_request(&mut self) -> Result<u64> {
190         // Return error if the endpoint is already in failed state.
191         self.check_state()?;
192 
193         // The underlying communication channel is a Unix domain socket in
194         // stream mode, and recvmsg() is a little tricky here. To successfully
195         // receive attached file descriptors, we need to receive messages and
196         // corresponding attached file descriptors in this way:
197         // . recv messsage header and optional attached file
198         // . validate message header
199         // . recv optional message body and payload according size field in
200         //   message header
201         // . validate message body and optional payload
202         let (hdr, files) = self.sub_sock.recv_header()?;
203         self.check_attached_files(&hdr, &files)?;
204         let (size, buf) = match hdr.get_size() {
205             0 => (0, vec![0u8; 0]),
206             len => {
207                 if len as usize > MAX_MSG_SIZE {
208                     return Err(Error::InvalidMessage);
209                 }
210                 let (size2, rbuf) = self.sub_sock.recv_data(len as usize)?;
211                 if size2 != len as usize {
212                     return Err(Error::InvalidMessage);
213                 }
214                 (size2, rbuf)
215             }
216         };
217 
218         let res = match hdr.get_code() {
219             Ok(SlaveReq::CONFIG_CHANGE_MSG) => {
220                 self.check_msg_size(&hdr, size, 0)?;
221                 self.backend
222                     .handle_config_change()
223                     .map_err(Error::ReqHandlerError)
224             }
225             Ok(SlaveReq::FS_MAP) => {
226                 let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?;
227                 // check_attached_files() has validated files
228                 self.backend
229                     .fs_slave_map(&msg, &files.unwrap()[0])
230                     .map_err(Error::ReqHandlerError)
231             }
232             Ok(SlaveReq::FS_UNMAP) => {
233                 let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?;
234                 self.backend
235                     .fs_slave_unmap(&msg)
236                     .map_err(Error::ReqHandlerError)
237             }
238             Ok(SlaveReq::FS_SYNC) => {
239                 let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?;
240                 self.backend
241                     .fs_slave_sync(&msg)
242                     .map_err(Error::ReqHandlerError)
243             }
244             Ok(SlaveReq::FS_IO) => {
245                 let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?;
246                 // check_attached_files() has validated files
247                 self.backend
248                     .fs_slave_io(&msg, &files.unwrap()[0])
249                     .map_err(Error::ReqHandlerError)
250             }
251             _ => Err(Error::InvalidMessage),
252         };
253 
254         self.send_ack_message(&hdr, &res)?;
255 
256         res
257     }
258 
check_state(&self) -> Result<()>259     fn check_state(&self) -> Result<()> {
260         match self.error {
261             Some(e) => Err(Error::SocketBroken(std::io::Error::from_raw_os_error(e))),
262             None => Ok(()),
263         }
264     }
265 
check_msg_size( &self, hdr: &VhostUserMsgHeader<SlaveReq>, size: usize, expected: usize, ) -> Result<()>266     fn check_msg_size(
267         &self,
268         hdr: &VhostUserMsgHeader<SlaveReq>,
269         size: usize,
270         expected: usize,
271     ) -> Result<()> {
272         if hdr.get_size() as usize != expected
273             || hdr.is_reply()
274             || hdr.get_version() != 0x1
275             || size != expected
276         {
277             return Err(Error::InvalidMessage);
278         }
279         Ok(())
280     }
281 
check_attached_files( &self, hdr: &VhostUserMsgHeader<SlaveReq>, files: &Option<Vec<File>>, ) -> Result<()>282     fn check_attached_files(
283         &self,
284         hdr: &VhostUserMsgHeader<SlaveReq>,
285         files: &Option<Vec<File>>,
286     ) -> Result<()> {
287         match hdr.get_code() {
288             Ok(SlaveReq::FS_MAP | SlaveReq::FS_IO) => {
289                 // Expect a single file is passed.
290                 match files {
291                     Some(files) if files.len() == 1 => Ok(()),
292                     _ => Err(Error::InvalidMessage),
293                 }
294             }
295             _ if files.is_some() => Err(Error::InvalidMessage),
296             _ => Ok(()),
297         }
298     }
299 
extract_msg_body<T: Sized + VhostUserMsgValidator>( &self, hdr: &VhostUserMsgHeader<SlaveReq>, size: usize, buf: &[u8], ) -> Result<T>300     fn extract_msg_body<T: Sized + VhostUserMsgValidator>(
301         &self,
302         hdr: &VhostUserMsgHeader<SlaveReq>,
303         size: usize,
304         buf: &[u8],
305     ) -> Result<T> {
306         self.check_msg_size(hdr, size, mem::size_of::<T>())?;
307         // SAFETY: Safe because we checked that `buf` size is equal to T size.
308         let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const T) };
309         if !msg.is_valid() {
310             return Err(Error::InvalidMessage);
311         }
312         Ok(msg)
313     }
314 
new_reply_header<T: Sized>( &self, req: &VhostUserMsgHeader<SlaveReq>, ) -> Result<VhostUserMsgHeader<SlaveReq>>315     fn new_reply_header<T: Sized>(
316         &self,
317         req: &VhostUserMsgHeader<SlaveReq>,
318     ) -> Result<VhostUserMsgHeader<SlaveReq>> {
319         if mem::size_of::<T>() > MAX_MSG_SIZE {
320             return Err(Error::InvalidParam);
321         }
322         self.check_state()?;
323         Ok(VhostUserMsgHeader::new(
324             req.get_code()?,
325             VhostUserHeaderFlag::REPLY.bits(),
326             mem::size_of::<T>() as u32,
327         ))
328     }
329 
send_ack_message( &mut self, req: &VhostUserMsgHeader<SlaveReq>, res: &Result<u64>, ) -> Result<()>330     fn send_ack_message(
331         &mut self,
332         req: &VhostUserMsgHeader<SlaveReq>,
333         res: &Result<u64>,
334     ) -> Result<()> {
335         if self.reply_ack_negotiated && req.is_need_reply() {
336             let hdr = self.new_reply_header::<VhostUserU64>(req)?;
337             let def_err = libc::EINVAL;
338             let val = match res {
339                 Ok(n) => *n,
340                 Err(e) => match e {
341                     Error::ReqHandlerError(ioerr) => match ioerr.raw_os_error() {
342                         Some(rawerr) => -rawerr as u64,
343                         None => -def_err as u64,
344                     },
345                     _ => -def_err as u64,
346                 },
347             };
348             let msg = VhostUserU64::new(val);
349             self.sub_sock.send_message(&hdr, &msg, None)?;
350         }
351         Ok(())
352     }
353 }
354 
355 impl<S: VhostUserMasterReqHandler> AsRawFd for MasterReqHandler<S> {
as_raw_fd(&self) -> RawFd356     fn as_raw_fd(&self) -> RawFd {
357         self.sub_sock.as_raw_fd()
358     }
359 }
360 
#[cfg(test)]
mod tests {
    use super::*;

    #[cfg(feature = "vhost-user-slave")]
    use crate::vhost_user::Slave;
    #[cfg(feature = "vhost-user-slave")]
    use std::os::unix::io::FromRawFd;

    struct MockMasterReqHandler {}

    impl VhostUserMasterReqHandlerMut for MockMasterReqHandler {
        /// Handle virtio-fs map file requests from the slave.
        fn fs_slave_map(
            &mut self,
            _fs: &VhostUserFSSlaveMsg,
            _fd: &dyn AsRawFd,
        ) -> HandlerResult<u64> {
            Ok(0)
        }

        /// Handle virtio-fs unmap file requests from the slave.
        fn fs_slave_unmap(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> {
            Err(std::io::Error::from_raw_os_error(libc::ENOSYS))
        }
    }

    /// Build a [Slave] talking to `handler` through a duplicate of its tx socket.
    ///
    /// Also returns the duplicated fd (owned by the slave's stream) so tests can attach it to
    /// requests that require a file.
    #[cfg(feature = "vhost-user-slave")]
    fn connect_slave<S: VhostUserMasterReqHandler>(
        handler: &MasterReqHandler<S>,
    ) -> (Slave, RawFd) {
        // SAFETY: Safe because `handler` contains valid fds, and we are
        // checking if `dup` returns a valid fd.
        let fd = unsafe { libc::dup(handler.get_tx_raw_fd()) };
        assert!(fd >= 0, "failed to duplicate tx fd!");
        // SAFETY: Safe because we checked that fd is valid and nothing else owns it.
        let stream = unsafe { UnixStream::from_raw_fd(fd) };
        (Slave::from_stream(stream), fd)
    }

    #[test]
    fn test_new_master_req_handler() {
        let backend = Arc::new(Mutex::new(MockMasterReqHandler {}));
        let mut handler = MasterReqHandler::new(backend).unwrap();

        assert!(handler.get_tx_raw_fd() >= 0);
        assert!(handler.as_raw_fd() >= 0);
        handler.check_state().unwrap();

        assert_eq!(handler.error, None);
        handler.set_failed(libc::EAGAIN);
        assert_eq!(handler.error, Some(libc::EAGAIN));
        handler.check_state().unwrap_err();
    }

    #[cfg(feature = "vhost-user-slave")]
    #[test]
    fn test_master_slave_req_handler() {
        let backend = Arc::new(Mutex::new(MockMasterReqHandler {}));
        let mut handler = MasterReqHandler::new(backend).unwrap();
        let (slave, fd) = connect_slave(&handler);

        let worker = std::thread::spawn(move || {
            let res = handler.handle_request().unwrap();
            assert_eq!(res, 0);
            handler.handle_request().unwrap_err();
        });

        slave
            .fs_slave_map(&VhostUserFSSlaveMsg::default(), &fd)
            .unwrap();
        // When REPLY_ACK has not been negotiated, the master has no way to detect failure from
        // slave side.
        slave
            .fs_slave_unmap(&VhostUserFSSlaveMsg::default())
            .unwrap();

        // Join so that a failed assertion in the handler thread fails this test.
        worker.join().unwrap();
    }

    #[cfg(feature = "vhost-user-slave")]
    #[test]
    fn test_master_slave_req_handler_with_ack() {
        let backend = Arc::new(Mutex::new(MockMasterReqHandler {}));
        let mut handler = MasterReqHandler::new(backend).unwrap();
        handler.set_reply_ack_flag(true);
        let (slave, fd) = connect_slave(&handler);

        let worker = std::thread::spawn(move || {
            let res = handler.handle_request().unwrap();
            assert_eq!(res, 0);
            handler.handle_request().unwrap_err();
        });

        slave.set_reply_ack_flag(true);
        slave
            .fs_slave_map(&VhostUserFSSlaveMsg::default(), &fd)
            .unwrap();
        slave
            .fs_slave_unmap(&VhostUserFSSlaveMsg::default())
            .unwrap_err();

        // Join so that a failed assertion in the handler thread fails this test.
        worker.join().unwrap();
    }
}
467