1 // Copyright (C) 2019-2021 Alibaba Cloud. All rights reserved. 2 // SPDX-License-Identifier: Apache-2.0 3 4 use std::fs::File; 5 use std::mem; 6 use std::os::unix::io::{AsRawFd, RawFd}; 7 use std::os::unix::net::UnixStream; 8 use std::sync::{Arc, Mutex}; 9 10 use super::connection::Endpoint; 11 use super::message::*; 12 use super::{Error, HandlerResult, Result}; 13 14 /// Define services provided by masters for the slave communication channel. 15 /// 16 /// The vhost-user specification defines a slave communication channel, by which slaves could 17 /// request services from masters. The [VhostUserMasterReqHandler] trait defines services provided 18 /// by masters, and it's used both on the master side and slave side. 19 /// - on the slave side, a stub forwarder implementing [VhostUserMasterReqHandler] will proxy 20 /// service requests to masters. The [Slave] is an example stub forwarder. 21 /// - on the master side, the [MasterReqHandler] will forward service requests to a handler 22 /// implementing [VhostUserMasterReqHandler]. 23 /// 24 /// The [VhostUserMasterReqHandler] trait is design with interior mutability to improve performance 25 /// for multi-threading. 26 /// 27 /// [VhostUserMasterReqHandler]: trait.VhostUserMasterReqHandler.html 28 /// [MasterReqHandler]: struct.MasterReqHandler.html 29 /// [Slave]: struct.Slave.html 30 pub trait VhostUserMasterReqHandler { 31 /// Handle device configuration change notifications. handle_config_change(&self) -> HandlerResult<u64>32 fn handle_config_change(&self) -> HandlerResult<u64> { 33 Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) 34 } 35 36 /// Handle virtio-fs map file requests. fs_slave_map(&self, _fs: &VhostUserFSSlaveMsg, _fd: &dyn AsRawFd) -> HandlerResult<u64>37 fn fs_slave_map(&self, _fs: &VhostUserFSSlaveMsg, _fd: &dyn AsRawFd) -> HandlerResult<u64> { 38 Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) 39 } 40 41 /// Handle virtio-fs unmap file requests. fs_slave_unmap(&self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64>42 fn fs_slave_unmap(&self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> { 43 Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) 44 } 45 46 /// Handle virtio-fs sync file requests. fs_slave_sync(&self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64>47 fn fs_slave_sync(&self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> { 48 Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) 49 } 50 51 /// Handle virtio-fs file IO requests. fs_slave_io(&self, _fs: &VhostUserFSSlaveMsg, _fd: &dyn AsRawFd) -> HandlerResult<u64>52 fn fs_slave_io(&self, _fs: &VhostUserFSSlaveMsg, _fd: &dyn AsRawFd) -> HandlerResult<u64> { 53 Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) 54 } 55 56 // fn handle_iotlb_msg(&mut self, iotlb: VhostUserIotlb); 57 // fn handle_vring_host_notifier(&mut self, area: VhostUserVringArea, fd: &dyn AsRawFd); 58 } 59 60 /// A helper trait mirroring [VhostUserMasterReqHandler] but without interior mutability. 61 /// 62 /// [VhostUserMasterReqHandler]: trait.VhostUserMasterReqHandler.html 63 pub trait VhostUserMasterReqHandlerMut { 64 /// Handle device configuration change notifications. handle_config_change(&mut self) -> HandlerResult<u64>65 fn handle_config_change(&mut self) -> HandlerResult<u64> { 66 Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) 67 } 68 69 /// Handle virtio-fs map file requests. fs_slave_map(&mut self, _fs: &VhostUserFSSlaveMsg, _fd: &dyn AsRawFd) -> HandlerResult<u64>70 fn fs_slave_map(&mut self, _fs: &VhostUserFSSlaveMsg, _fd: &dyn AsRawFd) -> HandlerResult<u64> { 71 Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) 72 } 73 74 /// Handle virtio-fs unmap file requests. fs_slave_unmap(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64>75 fn fs_slave_unmap(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> { 76 Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) 77 } 78 79 /// Handle virtio-fs sync file requests. fs_slave_sync(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64>80 fn fs_slave_sync(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> { 81 Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) 82 } 83 84 /// Handle virtio-fs file IO requests. fs_slave_io(&mut self, _fs: &VhostUserFSSlaveMsg, _fd: &dyn AsRawFd) -> HandlerResult<u64>85 fn fs_slave_io(&mut self, _fs: &VhostUserFSSlaveMsg, _fd: &dyn AsRawFd) -> HandlerResult<u64> { 86 Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) 87 } 88 89 // fn handle_iotlb_msg(&mut self, iotlb: VhostUserIotlb); 90 // fn handle_vring_host_notifier(&mut self, area: VhostUserVringArea, fd: RawFd); 91 } 92 93 impl<S: VhostUserMasterReqHandlerMut> VhostUserMasterReqHandler for Mutex<S> { handle_config_change(&self) -> HandlerResult<u64>94 fn handle_config_change(&self) -> HandlerResult<u64> { 95 self.lock().unwrap().handle_config_change() 96 } 97 fs_slave_map(&self, fs: &VhostUserFSSlaveMsg, fd: &dyn AsRawFd) -> HandlerResult<u64>98 fn fs_slave_map(&self, fs: &VhostUserFSSlaveMsg, fd: &dyn AsRawFd) -> HandlerResult<u64> { 99 self.lock().unwrap().fs_slave_map(fs, fd) 100 } 101 fs_slave_unmap(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64>102 fn fs_slave_unmap(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> { 103 self.lock().unwrap().fs_slave_unmap(fs) 104 } 105 fs_slave_sync(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64>106 fn fs_slave_sync(&self, fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> { 107 self.lock().unwrap().fs_slave_sync(fs) 108 } 109 fs_slave_io(&self, fs: &VhostUserFSSlaveMsg, fd: &dyn AsRawFd) -> HandlerResult<u64>110 fn fs_slave_io(&self, fs: &VhostUserFSSlaveMsg, fd: &dyn AsRawFd) -> HandlerResult<u64> { 111 self.lock().unwrap().fs_slave_io(fs, fd) 112 } 113 } 114 115 /// Server to handle service requests from slaves from the slave communication channel. 116 /// 117 /// The [MasterReqHandler] acts as a server on the master side, to handle service requests from 118 /// slaves on the slave communication channel. It's actually a proxy invoking the registered 119 /// handler implementing [VhostUserMasterReqHandler] to do the real work. 120 /// 121 /// [MasterReqHandler]: struct.MasterReqHandler.html 122 /// [VhostUserMasterReqHandler]: trait.VhostUserMasterReqHandler.html 123 pub struct MasterReqHandler<S: VhostUserMasterReqHandler> { 124 // underlying Unix domain socket for communication 125 sub_sock: Endpoint<SlaveReq>, 126 tx_sock: UnixStream, 127 // Protocol feature VHOST_USER_PROTOCOL_F_REPLY_ACK has been negotiated. 128 reply_ack_negotiated: bool, 129 // the VirtIO backend device object 130 backend: Arc<S>, 131 // whether the endpoint has encountered any failure 132 error: Option<i32>, 133 } 134 135 impl<S: VhostUserMasterReqHandler> MasterReqHandler<S> { 136 /// Create a server to handle service requests from slaves on the slave communication channel. 137 /// 138 /// This opens a pair of connected anonymous sockets to form the slave communication channel. 139 /// The socket fd returned by [Self::get_tx_raw_fd()] should be sent to the slave by 140 /// [VhostUserMaster::set_slave_request_fd()]. 141 /// 142 /// [Self::get_tx_raw_fd()]: struct.MasterReqHandler.html#method.get_tx_raw_fd 143 /// [VhostUserMaster::set_slave_request_fd()]: trait.VhostUserMaster.html#tymethod.set_slave_request_fd new(backend: Arc<S>) -> Result<Self>144 pub fn new(backend: Arc<S>) -> Result<Self> { 145 let (tx, rx) = UnixStream::pair().map_err(Error::SocketError)?; 146 147 Ok(MasterReqHandler { 148 sub_sock: Endpoint::<SlaveReq>::from_stream(rx), 149 tx_sock: tx, 150 reply_ack_negotiated: false, 151 backend, 152 error: None, 153 }) 154 } 155 156 /// Get the socket fd for the slave to communication with the master. 157 /// 158 /// The returned fd should be sent to the slave by [VhostUserMaster::set_slave_request_fd()]. 159 /// 160 /// [VhostUserMaster::set_slave_request_fd()]: trait.VhostUserMaster.html#tymethod.set_slave_request_fd get_tx_raw_fd(&self) -> RawFd161 pub fn get_tx_raw_fd(&self) -> RawFd { 162 self.tx_sock.as_raw_fd() 163 } 164 165 /// Set the negotiation state of the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature. 166 /// 167 /// When the `VHOST_USER_PROTOCOL_F_REPLY_ACK` protocol feature has been negotiated, 168 /// the "REPLY_ACK" flag will be set in the message header for every slave to master request 169 /// message. set_reply_ack_flag(&mut self, enable: bool)170 pub fn set_reply_ack_flag(&mut self, enable: bool) { 171 self.reply_ack_negotiated = enable; 172 } 173 174 /// Mark endpoint as failed or in normal state. set_failed(&mut self, error: i32)175 pub fn set_failed(&mut self, error: i32) { 176 if error == 0 { 177 self.error = None; 178 } else { 179 self.error = Some(error); 180 } 181 } 182 183 /// Main entrance to server slave request from the slave communication channel. 184 /// 185 /// The caller needs to: 186 /// - serialize calls to this function 187 /// - decide what to do when errer happens 188 /// - optional recover from failure handle_request(&mut self) -> Result<u64>189 pub fn handle_request(&mut self) -> Result<u64> { 190 // Return error if the endpoint is already in failed state. 191 self.check_state()?; 192 193 // The underlying communication channel is a Unix domain socket in 194 // stream mode, and recvmsg() is a little tricky here. To successfully 195 // receive attached file descriptors, we need to receive messages and 196 // corresponding attached file descriptors in this way: 197 // . recv messsage header and optional attached file 198 // . validate message header 199 // . recv optional message body and payload according size field in 200 // message header 201 // . validate message body and optional payload 202 let (hdr, files) = self.sub_sock.recv_header()?; 203 self.check_attached_files(&hdr, &files)?; 204 let (size, buf) = match hdr.get_size() { 205 0 => (0, vec![0u8; 0]), 206 len => { 207 if len as usize > MAX_MSG_SIZE { 208 return Err(Error::InvalidMessage); 209 } 210 let (size2, rbuf) = self.sub_sock.recv_data(len as usize)?; 211 if size2 != len as usize { 212 return Err(Error::InvalidMessage); 213 } 214 (size2, rbuf) 215 } 216 }; 217 218 let res = match hdr.get_code() { 219 Ok(SlaveReq::CONFIG_CHANGE_MSG) => { 220 self.check_msg_size(&hdr, size, 0)?; 221 self.backend 222 .handle_config_change() 223 .map_err(Error::ReqHandlerError) 224 } 225 Ok(SlaveReq::FS_MAP) => { 226 let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?; 227 // check_attached_files() has validated files 228 self.backend 229 .fs_slave_map(&msg, &files.unwrap()[0]) 230 .map_err(Error::ReqHandlerError) 231 } 232 Ok(SlaveReq::FS_UNMAP) => { 233 let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?; 234 self.backend 235 .fs_slave_unmap(&msg) 236 .map_err(Error::ReqHandlerError) 237 } 238 Ok(SlaveReq::FS_SYNC) => { 239 let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?; 240 self.backend 241 .fs_slave_sync(&msg) 242 .map_err(Error::ReqHandlerError) 243 } 244 Ok(SlaveReq::FS_IO) => { 245 let msg = self.extract_msg_body::<VhostUserFSSlaveMsg>(&hdr, size, &buf)?; 246 // check_attached_files() has validated files 247 self.backend 248 .fs_slave_io(&msg, &files.unwrap()[0]) 249 .map_err(Error::ReqHandlerError) 250 } 251 _ => Err(Error::InvalidMessage), 252 }; 253 254 self.send_ack_message(&hdr, &res)?; 255 256 res 257 } 258 check_state(&self) -> Result<()>259 fn check_state(&self) -> Result<()> { 260 match self.error { 261 Some(e) => Err(Error::SocketBroken(std::io::Error::from_raw_os_error(e))), 262 None => Ok(()), 263 } 264 } 265 check_msg_size( &self, hdr: &VhostUserMsgHeader<SlaveReq>, size: usize, expected: usize, ) -> Result<()>266 fn check_msg_size( 267 &self, 268 hdr: &VhostUserMsgHeader<SlaveReq>, 269 size: usize, 270 expected: usize, 271 ) -> Result<()> { 272 if hdr.get_size() as usize != expected 273 || hdr.is_reply() 274 || hdr.get_version() != 0x1 275 || size != expected 276 { 277 return Err(Error::InvalidMessage); 278 } 279 Ok(()) 280 } 281 check_attached_files( &self, hdr: &VhostUserMsgHeader<SlaveReq>, files: &Option<Vec<File>>, ) -> Result<()>282 fn check_attached_files( 283 &self, 284 hdr: &VhostUserMsgHeader<SlaveReq>, 285 files: &Option<Vec<File>>, 286 ) -> Result<()> { 287 match hdr.get_code() { 288 Ok(SlaveReq::FS_MAP | SlaveReq::FS_IO) => { 289 // Expect a single file is passed. 290 match files { 291 Some(files) if files.len() == 1 => Ok(()), 292 _ => Err(Error::InvalidMessage), 293 } 294 } 295 _ if files.is_some() => Err(Error::InvalidMessage), 296 _ => Ok(()), 297 } 298 } 299 extract_msg_body<T: Sized + VhostUserMsgValidator>( &self, hdr: &VhostUserMsgHeader<SlaveReq>, size: usize, buf: &[u8], ) -> Result<T>300 fn extract_msg_body<T: Sized + VhostUserMsgValidator>( 301 &self, 302 hdr: &VhostUserMsgHeader<SlaveReq>, 303 size: usize, 304 buf: &[u8], 305 ) -> Result<T> { 306 self.check_msg_size(hdr, size, mem::size_of::<T>())?; 307 // SAFETY: Safe because we checked that `buf` size is equal to T size. 308 let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const T) }; 309 if !msg.is_valid() { 310 return Err(Error::InvalidMessage); 311 } 312 Ok(msg) 313 } 314 new_reply_header<T: Sized>( &self, req: &VhostUserMsgHeader<SlaveReq>, ) -> Result<VhostUserMsgHeader<SlaveReq>>315 fn new_reply_header<T: Sized>( 316 &self, 317 req: &VhostUserMsgHeader<SlaveReq>, 318 ) -> Result<VhostUserMsgHeader<SlaveReq>> { 319 if mem::size_of::<T>() > MAX_MSG_SIZE { 320 return Err(Error::InvalidParam); 321 } 322 self.check_state()?; 323 Ok(VhostUserMsgHeader::new( 324 req.get_code()?, 325 VhostUserHeaderFlag::REPLY.bits(), 326 mem::size_of::<T>() as u32, 327 )) 328 } 329 send_ack_message( &mut self, req: &VhostUserMsgHeader<SlaveReq>, res: &Result<u64>, ) -> Result<()>330 fn send_ack_message( 331 &mut self, 332 req: &VhostUserMsgHeader<SlaveReq>, 333 res: &Result<u64>, 334 ) -> Result<()> { 335 if self.reply_ack_negotiated && req.is_need_reply() { 336 let hdr = self.new_reply_header::<VhostUserU64>(req)?; 337 let def_err = libc::EINVAL; 338 let val = match res { 339 Ok(n) => *n, 340 Err(e) => match e { 341 Error::ReqHandlerError(ioerr) => match ioerr.raw_os_error() { 342 Some(rawerr) => -rawerr as u64, 343 None => -def_err as u64, 344 }, 345 _ => -def_err as u64, 346 }, 347 }; 348 let msg = VhostUserU64::new(val); 349 self.sub_sock.send_message(&hdr, &msg, None)?; 350 } 351 Ok(()) 352 } 353 } 354 355 impl<S: VhostUserMasterReqHandler> AsRawFd for MasterReqHandler<S> { as_raw_fd(&self) -> RawFd356 fn as_raw_fd(&self) -> RawFd { 357 self.sub_sock.as_raw_fd() 358 } 359 } 360 361 #[cfg(test)] 362 mod tests { 363 use super::*; 364 365 #[cfg(feature = "vhost-user-slave")] 366 use crate::vhost_user::Slave; 367 #[cfg(feature = "vhost-user-slave")] 368 use std::os::unix::io::FromRawFd; 369 370 struct MockMasterReqHandler {} 371 372 impl VhostUserMasterReqHandlerMut for MockMasterReqHandler { 373 /// Handle virtio-fs map file requests from the slave. fs_slave_map( &mut self, _fs: &VhostUserFSSlaveMsg, _fd: &dyn AsRawFd, ) -> HandlerResult<u64>374 fn fs_slave_map( 375 &mut self, 376 _fs: &VhostUserFSSlaveMsg, 377 _fd: &dyn AsRawFd, 378 ) -> HandlerResult<u64> { 379 Ok(0) 380 } 381 382 /// Handle virtio-fs unmap file requests from the slave. fs_slave_unmap(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64>383 fn fs_slave_unmap(&mut self, _fs: &VhostUserFSSlaveMsg) -> HandlerResult<u64> { 384 Err(std::io::Error::from_raw_os_error(libc::ENOSYS)) 385 } 386 } 387 388 #[test] test_new_master_req_handler()389 fn test_new_master_req_handler() { 390 let backend = Arc::new(Mutex::new(MockMasterReqHandler {})); 391 let mut handler = MasterReqHandler::new(backend).unwrap(); 392 393 assert!(handler.get_tx_raw_fd() >= 0); 394 assert!(handler.as_raw_fd() >= 0); 395 handler.check_state().unwrap(); 396 397 assert_eq!(handler.error, None); 398 handler.set_failed(libc::EAGAIN); 399 assert_eq!(handler.error, Some(libc::EAGAIN)); 400 handler.check_state().unwrap_err(); 401 } 402 403 #[cfg(feature = "vhost-user-slave")] 404 #[test] test_master_slave_req_handler()405 fn test_master_slave_req_handler() { 406 let backend = Arc::new(Mutex::new(MockMasterReqHandler {})); 407 let mut handler = MasterReqHandler::new(backend).unwrap(); 408 409 // SAFETY: Safe because `handler` contains valid fds, and we are 410 // checking if `dup` returns a valid fd. 411 let fd = unsafe { libc::dup(handler.get_tx_raw_fd()) }; 412 if fd < 0 { 413 panic!("failed to duplicated tx fd!"); 414 } 415 // SAFETY: Safe because we checked if fd is valid. 416 let stream = unsafe { UnixStream::from_raw_fd(fd) }; 417 let slave = Slave::from_stream(stream); 418 419 std::thread::spawn(move || { 420 let res = handler.handle_request().unwrap(); 421 assert_eq!(res, 0); 422 handler.handle_request().unwrap_err(); 423 }); 424 425 slave 426 .fs_slave_map(&VhostUserFSSlaveMsg::default(), &fd) 427 .unwrap(); 428 // When REPLY_ACK has not been negotiated, the master has no way to detect failure from 429 // slave side. 430 slave 431 .fs_slave_unmap(&VhostUserFSSlaveMsg::default()) 432 .unwrap(); 433 } 434 435 #[cfg(feature = "vhost-user-slave")] 436 #[test] test_master_slave_req_handler_with_ack()437 fn test_master_slave_req_handler_with_ack() { 438 let backend = Arc::new(Mutex::new(MockMasterReqHandler {})); 439 let mut handler = MasterReqHandler::new(backend).unwrap(); 440 handler.set_reply_ack_flag(true); 441 442 // SAFETY: Safe because `handler` contains valid fds, and we are 443 // checking if `dup` returns a valid fd. 444 let fd = unsafe { libc::dup(handler.get_tx_raw_fd()) }; 445 if fd < 0 { 446 panic!("failed to duplicated tx fd!"); 447 } 448 // SAFETY: Safe because we checked if fd is valid. 449 let stream = unsafe { UnixStream::from_raw_fd(fd) }; 450 let slave = Slave::from_stream(stream); 451 452 std::thread::spawn(move || { 453 let res = handler.handle_request().unwrap(); 454 assert_eq!(res, 0); 455 handler.handle_request().unwrap_err(); 456 }); 457 458 slave.set_reply_ack_flag(true); 459 slave 460 .fs_slave_map(&VhostUserFSSlaveMsg::default(), &fd) 461 .unwrap(); 462 slave 463 .fs_slave_unmap(&VhostUserFSSlaveMsg::default()) 464 .unwrap_err(); 465 } 466 } 467