1 // Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved. 2 // SPDX-License-Identifier: Apache-2.0 3 4 use std::fs::File; 5 use std::mem; 6 use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}; 7 use std::os::unix::net::UnixStream; 8 use std::slice; 9 use std::sync::{Arc, Mutex}; 10 11 use vm_memory::ByteValued; 12 13 use super::connection::Endpoint; 14 use super::message::*; 15 use super::slave_req::Slave; 16 use super::{take_single_file, Error, Result}; 17 18 /// Services provided to the master by the slave with interior mutability. 19 /// 20 /// The [VhostUserSlaveReqHandler] trait defines the services provided to the master by the slave. 21 /// And the [VhostUserSlaveReqHandlerMut] trait is a helper mirroring [VhostUserSlaveReqHandler], 22 /// but without interior mutability. 23 /// The vhost-user specification defines a master communication channel, by which masters could 24 /// request services from slaves. The [VhostUserSlaveReqHandler] trait defines services provided by 25 /// slaves, and it's used both on the master side and slave side. 26 /// 27 /// - on the master side, a stub forwarder implementing [VhostUserSlaveReqHandler] will proxy 28 /// service requests to slaves. 29 /// - on the slave side, the [SlaveReqHandler] will forward service requests to a handler 30 /// implementing [VhostUserSlaveReqHandler]. 31 /// 32 /// The [VhostUserSlaveReqHandler] trait is design with interior mutability to improve performance 33 /// for multi-threading. 34 /// 35 /// [VhostUserSlaveReqHandler]: trait.VhostUserSlaveReqHandler.html 36 /// [VhostUserSlaveReqHandlerMut]: trait.VhostUserSlaveReqHandlerMut.html 37 /// [SlaveReqHandler]: struct.SlaveReqHandler.html 38 #[allow(missing_docs)] 39 pub trait VhostUserSlaveReqHandler { set_owner(&self) -> Result<()>40 fn set_owner(&self) -> Result<()>; reset_owner(&self) -> Result<()>41 fn reset_owner(&self) -> Result<()>; get_features(&self) -> Result<u64>42 fn get_features(&self) -> Result<u64>; set_features(&self, features: u64) -> Result<()>43 fn set_features(&self, features: u64) -> Result<()>; set_mem_table(&self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>44 fn set_mem_table(&self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>; set_vring_num(&self, index: u32, num: u32) -> Result<()>45 fn set_vring_num(&self, index: u32, num: u32) -> Result<()>; set_vring_addr( &self, index: u32, flags: VhostUserVringAddrFlags, descriptor: u64, used: u64, available: u64, log: u64, ) -> Result<()>46 fn set_vring_addr( 47 &self, 48 index: u32, 49 flags: VhostUserVringAddrFlags, 50 descriptor: u64, 51 used: u64, 52 available: u64, 53 log: u64, 54 ) -> Result<()>; set_vring_base(&self, index: u32, base: u32) -> Result<()>55 fn set_vring_base(&self, index: u32, base: u32) -> Result<()>; get_vring_base(&self, index: u32) -> Result<VhostUserVringState>56 fn get_vring_base(&self, index: u32) -> Result<VhostUserVringState>; set_vring_kick(&self, index: u8, fd: Option<File>) -> Result<()>57 fn set_vring_kick(&self, index: u8, fd: Option<File>) -> Result<()>; set_vring_call(&self, index: u8, fd: Option<File>) -> Result<()>58 fn set_vring_call(&self, index: u8, fd: Option<File>) -> Result<()>; set_vring_err(&self, index: u8, fd: Option<File>) -> Result<()>59 fn set_vring_err(&self, index: u8, fd: Option<File>) -> Result<()>; 60 get_protocol_features(&self) -> Result<VhostUserProtocolFeatures>61 fn get_protocol_features(&self) -> Result<VhostUserProtocolFeatures>; set_protocol_features(&self, features: u64) -> Result<()>62 fn set_protocol_features(&self, features: u64) -> Result<()>; get_queue_num(&self) -> Result<u64>63 fn get_queue_num(&self) -> Result<u64>; set_vring_enable(&self, index: u32, enable: bool) -> Result<()>64 fn set_vring_enable(&self, index: u32, enable: bool) -> Result<()>; get_config(&self, offset: u32, size: u32, flags: VhostUserConfigFlags) -> Result<Vec<u8>>65 fn get_config(&self, offset: u32, size: u32, flags: VhostUserConfigFlags) -> Result<Vec<u8>>; set_config(&self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>66 fn set_config(&self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>; set_slave_req_fd(&self, _slave: Slave)67 fn set_slave_req_fd(&self, _slave: Slave) {} get_inflight_fd(&self, inflight: &VhostUserInflight) -> Result<(VhostUserInflight, File)>68 fn get_inflight_fd(&self, inflight: &VhostUserInflight) -> Result<(VhostUserInflight, File)>; set_inflight_fd(&self, inflight: &VhostUserInflight, file: File) -> Result<()>69 fn set_inflight_fd(&self, inflight: &VhostUserInflight, file: File) -> Result<()>; get_max_mem_slots(&self) -> Result<u64>70 fn get_max_mem_slots(&self) -> Result<u64>; add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>71 fn add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>; remove_mem_region(&self, region: &VhostUserSingleMemoryRegion) -> Result<()>72 fn remove_mem_region(&self, region: &VhostUserSingleMemoryRegion) -> Result<()>; 73 } 74 75 /// Services provided to the master by the slave without interior mutability. 76 /// 77 /// This is a helper trait mirroring the [VhostUserSlaveReqHandler] trait. 78 #[allow(missing_docs)] 79 pub trait VhostUserSlaveReqHandlerMut { set_owner(&mut self) -> Result<()>80 fn set_owner(&mut self) -> Result<()>; reset_owner(&mut self) -> Result<()>81 fn reset_owner(&mut self) -> Result<()>; get_features(&mut self) -> Result<u64>82 fn get_features(&mut self) -> Result<u64>; set_features(&mut self, features: u64) -> Result<()>83 fn set_features(&mut self, features: u64) -> Result<()>; set_mem_table(&mut self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>84 fn set_mem_table(&mut self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>; set_vring_num(&mut self, index: u32, num: u32) -> Result<()>85 fn set_vring_num(&mut self, index: u32, num: u32) -> Result<()>; set_vring_addr( &mut self, index: u32, flags: VhostUserVringAddrFlags, descriptor: u64, used: u64, available: u64, log: u64, ) -> Result<()>86 fn set_vring_addr( 87 &mut self, 88 index: u32, 89 flags: VhostUserVringAddrFlags, 90 descriptor: u64, 91 used: u64, 92 available: u64, 93 log: u64, 94 ) -> Result<()>; set_vring_base(&mut self, index: u32, base: u32) -> Result<()>95 fn set_vring_base(&mut self, index: u32, base: u32) -> Result<()>; get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState>96 fn get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState>; set_vring_kick(&mut self, index: u8, fd: Option<File>) -> Result<()>97 fn set_vring_kick(&mut self, index: u8, fd: Option<File>) -> Result<()>; set_vring_call(&mut self, index: u8, fd: Option<File>) -> Result<()>98 fn set_vring_call(&mut self, index: u8, fd: Option<File>) -> Result<()>; set_vring_err(&mut self, index: u8, fd: Option<File>) -> Result<()>99 fn set_vring_err(&mut self, index: u8, fd: Option<File>) -> Result<()>; 100 get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures>101 fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures>; set_protocol_features(&mut self, features: u64) -> Result<()>102 fn set_protocol_features(&mut self, features: u64) -> Result<()>; get_queue_num(&mut self) -> Result<u64>103 fn get_queue_num(&mut self) -> Result<u64>; set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()>104 fn set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()>; get_config( &mut self, offset: u32, size: u32, flags: VhostUserConfigFlags, ) -> Result<Vec<u8>>105 fn get_config( 106 &mut self, 107 offset: u32, 108 size: u32, 109 flags: VhostUserConfigFlags, 110 ) -> Result<Vec<u8>>; set_config(&mut self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>111 fn set_config(&mut self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>; set_slave_req_fd(&mut self, _slave: Slave)112 fn set_slave_req_fd(&mut self, _slave: Slave) {} get_inflight_fd( &mut self, inflight: &VhostUserInflight, ) -> Result<(VhostUserInflight, File)>113 fn get_inflight_fd( 114 &mut self, 115 inflight: &VhostUserInflight, 116 ) -> Result<(VhostUserInflight, File)>; set_inflight_fd(&mut self, inflight: &VhostUserInflight, file: File) -> Result<()>117 fn set_inflight_fd(&mut self, inflight: &VhostUserInflight, file: File) -> Result<()>; get_max_mem_slots(&mut self) -> Result<u64>118 fn get_max_mem_slots(&mut self) -> Result<u64>; add_mem_region(&mut self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>119 fn add_mem_region(&mut self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>; remove_mem_region(&mut self, region: &VhostUserSingleMemoryRegion) -> Result<()>120 fn remove_mem_region(&mut self, region: &VhostUserSingleMemoryRegion) -> Result<()>; 121 } 122 123 impl<T: VhostUserSlaveReqHandlerMut> VhostUserSlaveReqHandler for Mutex<T> { set_owner(&self) -> Result<()>124 fn set_owner(&self) -> Result<()> { 125 self.lock().unwrap().set_owner() 126 } 127 reset_owner(&self) -> Result<()>128 fn reset_owner(&self) -> Result<()> { 129 self.lock().unwrap().reset_owner() 130 } 131 get_features(&self) -> Result<u64>132 fn get_features(&self) -> Result<u64> { 133 self.lock().unwrap().get_features() 134 } 135 set_features(&self, features: u64) -> Result<()>136 fn set_features(&self, features: u64) -> Result<()> { 137 self.lock().unwrap().set_features(features) 138 } 139 set_mem_table(&self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>140 fn set_mem_table(&self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()> { 141 self.lock().unwrap().set_mem_table(ctx, files) 142 } 143 set_vring_num(&self, index: u32, num: u32) -> Result<()>144 fn set_vring_num(&self, index: u32, num: u32) -> Result<()> { 145 self.lock().unwrap().set_vring_num(index, num) 146 } 147 set_vring_addr( &self, index: u32, flags: VhostUserVringAddrFlags, descriptor: u64, used: u64, available: u64, log: u64, ) -> Result<()>148 fn set_vring_addr( 149 &self, 150 index: u32, 151 flags: VhostUserVringAddrFlags, 152 descriptor: u64, 153 used: u64, 154 available: u64, 155 log: u64, 156 ) -> Result<()> { 157 self.lock() 158 .unwrap() 159 .set_vring_addr(index, flags, descriptor, used, available, log) 160 } 161 set_vring_base(&self, index: u32, base: u32) -> Result<()>162 fn set_vring_base(&self, index: u32, base: u32) -> Result<()> { 163 self.lock().unwrap().set_vring_base(index, base) 164 } 165 get_vring_base(&self, index: u32) -> Result<VhostUserVringState>166 fn get_vring_base(&self, index: u32) -> Result<VhostUserVringState> { 167 self.lock().unwrap().get_vring_base(index) 168 } 169 set_vring_kick(&self, index: u8, fd: Option<File>) -> Result<()>170 fn set_vring_kick(&self, index: u8, fd: Option<File>) -> Result<()> { 171 self.lock().unwrap().set_vring_kick(index, fd) 172 } 173 set_vring_call(&self, index: u8, fd: Option<File>) -> Result<()>174 fn set_vring_call(&self, index: u8, fd: Option<File>) -> Result<()> { 175 self.lock().unwrap().set_vring_call(index, fd) 176 } 177 set_vring_err(&self, index: u8, fd: Option<File>) -> Result<()>178 fn set_vring_err(&self, index: u8, fd: Option<File>) -> Result<()> { 179 self.lock().unwrap().set_vring_err(index, fd) 180 } 181 get_protocol_features(&self) -> Result<VhostUserProtocolFeatures>182 fn get_protocol_features(&self) -> Result<VhostUserProtocolFeatures> { 183 self.lock().unwrap().get_protocol_features() 184 } 185 set_protocol_features(&self, features: u64) -> Result<()>186 fn set_protocol_features(&self, features: u64) -> Result<()> { 187 self.lock().unwrap().set_protocol_features(features) 188 } 189 get_queue_num(&self) -> Result<u64>190 fn get_queue_num(&self) -> Result<u64> { 191 self.lock().unwrap().get_queue_num() 192 } 193 set_vring_enable(&self, index: u32, enable: bool) -> Result<()>194 fn set_vring_enable(&self, index: u32, enable: bool) -> Result<()> { 195 self.lock().unwrap().set_vring_enable(index, enable) 196 } 197 get_config(&self, offset: u32, size: u32, flags: VhostUserConfigFlags) -> Result<Vec<u8>>198 fn get_config(&self, offset: u32, size: u32, flags: VhostUserConfigFlags) -> Result<Vec<u8>> { 199 self.lock().unwrap().get_config(offset, size, flags) 200 } 201 set_config(&self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>202 fn set_config(&self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()> { 203 self.lock().unwrap().set_config(offset, buf, flags) 204 } 205 set_slave_req_fd(&self, slave: Slave)206 fn set_slave_req_fd(&self, slave: Slave) { 207 self.lock().unwrap().set_slave_req_fd(slave) 208 } 209 get_inflight_fd(&self, inflight: &VhostUserInflight) -> Result<(VhostUserInflight, File)>210 fn get_inflight_fd(&self, inflight: &VhostUserInflight) -> Result<(VhostUserInflight, File)> { 211 self.lock().unwrap().get_inflight_fd(inflight) 212 } 213 set_inflight_fd(&self, inflight: &VhostUserInflight, file: File) -> Result<()>214 fn set_inflight_fd(&self, inflight: &VhostUserInflight, file: File) -> Result<()> { 215 self.lock().unwrap().set_inflight_fd(inflight, file) 216 } 217 get_max_mem_slots(&self) -> Result<u64>218 fn get_max_mem_slots(&self) -> Result<u64> { 219 self.lock().unwrap().get_max_mem_slots() 220 } 221 add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>222 fn add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()> { 223 self.lock().unwrap().add_mem_region(region, fd) 224 } 225 remove_mem_region(&self, region: &VhostUserSingleMemoryRegion) -> Result<()>226 fn remove_mem_region(&self, region: &VhostUserSingleMemoryRegion) -> Result<()> { 227 self.lock().unwrap().remove_mem_region(region) 228 } 229 } 230 231 /// Server to handle service requests from masters from the master communication channel. 232 /// 233 /// The [SlaveReqHandler] acts as a server on the slave side, to handle service requests from 234 /// masters on the master communication channel. It's actually a proxy invoking the registered 235 /// handler implementing [VhostUserSlaveReqHandler] to do the real work. 236 /// 237 /// The lifetime of the SlaveReqHandler object should be the same as the underline Unix Domain 238 /// Socket, so it gets simpler to recover from disconnect. 239 /// 240 /// [VhostUserSlaveReqHandler]: trait.VhostUserSlaveReqHandler.html 241 /// [SlaveReqHandler]: struct.SlaveReqHandler.html 242 pub struct SlaveReqHandler<S: VhostUserSlaveReqHandler> { 243 // underlying Unix domain socket for communication 244 main_sock: Endpoint<MasterReq>, 245 // the vhost-user backend device object 246 backend: Arc<S>, 247 248 virtio_features: u64, 249 acked_virtio_features: u64, 250 protocol_features: VhostUserProtocolFeatures, 251 acked_protocol_features: u64, 252 253 // sending ack for messages without payload 254 reply_ack_enabled: bool, 255 // whether the endpoint has encountered any failure 256 error: Option<i32>, 257 } 258 259 impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> { 260 /// Create a vhost-user slave endpoint. new(main_sock: Endpoint<MasterReq>, backend: Arc<S>) -> Self261 pub(super) fn new(main_sock: Endpoint<MasterReq>, backend: Arc<S>) -> Self { 262 SlaveReqHandler { 263 main_sock, 264 backend, 265 virtio_features: 0, 266 acked_virtio_features: 0, 267 protocol_features: VhostUserProtocolFeatures::empty(), 268 acked_protocol_features: 0, 269 reply_ack_enabled: false, 270 error: None, 271 } 272 } 273 check_feature(&self, feat: VhostUserVirtioFeatures) -> Result<()>274 fn check_feature(&self, feat: VhostUserVirtioFeatures) -> Result<()> { 275 if self.acked_virtio_features & feat.bits() != 0 { 276 Ok(()) 277 } else { 278 Err(Error::InactiveFeature(feat)) 279 } 280 } 281 check_proto_feature(&self, feat: VhostUserProtocolFeatures) -> Result<()>282 fn check_proto_feature(&self, feat: VhostUserProtocolFeatures) -> Result<()> { 283 if self.acked_protocol_features & feat.bits() != 0 { 284 Ok(()) 285 } else { 286 Err(Error::InactiveOperation(feat)) 287 } 288 } 289 290 /// Create a vhost-user slave endpoint from a connected socket. from_stream(socket: UnixStream, backend: Arc<S>) -> Self291 pub fn from_stream(socket: UnixStream, backend: Arc<S>) -> Self { 292 Self::new(Endpoint::from_stream(socket), backend) 293 } 294 295 /// Create a new vhost-user slave endpoint. 296 /// 297 /// # Arguments 298 /// * - `path` - path of Unix domain socket listener to connect to 299 /// * - `backend` - handler for requests from the master to the slave connect(path: &str, backend: Arc<S>) -> Result<Self>300 pub fn connect(path: &str, backend: Arc<S>) -> Result<Self> { 301 Ok(Self::new(Endpoint::<MasterReq>::connect(path)?, backend)) 302 } 303 304 /// Mark endpoint as failed with specified error code. set_failed(&mut self, error: i32)305 pub fn set_failed(&mut self, error: i32) { 306 self.error = Some(error); 307 } 308 309 /// Main entrance to server slave request from the slave communication channel. 310 /// 311 /// Receive and handle one incoming request message from the master. The caller needs to: 312 /// - serialize calls to this function 313 /// - decide what to do when error happens 314 /// - optional recover from failure handle_request(&mut self) -> Result<()>315 pub fn handle_request(&mut self) -> Result<()> { 316 // Return error if the endpoint is already in failed state. 317 self.check_state()?; 318 319 // The underlying communication channel is a Unix domain socket in 320 // stream mode, and recvmsg() is a little tricky here. To successfully 321 // receive attached file descriptors, we need to receive messages and 322 // corresponding attached file descriptors in this way: 323 // . recv messsage header and optional attached file 324 // . validate message header 325 // . recv optional message body and payload according size field in 326 // message header 327 // . validate message body and optional payload 328 let (hdr, files) = self.main_sock.recv_header()?; 329 self.check_attached_files(&hdr, &files)?; 330 331 let (size, buf) = match hdr.get_size() { 332 0 => (0, vec![0u8; 0]), 333 len => { 334 let (size2, rbuf) = self.main_sock.recv_data(len as usize)?; 335 if size2 != len as usize { 336 return Err(Error::InvalidMessage); 337 } 338 (size2, rbuf) 339 } 340 }; 341 342 match hdr.get_code() { 343 Ok(MasterReq::SET_OWNER) => { 344 self.check_request_size(&hdr, size, 0)?; 345 let res = self.backend.set_owner(); 346 self.send_ack_message(&hdr, res)?; 347 } 348 Ok(MasterReq::RESET_OWNER) => { 349 self.check_request_size(&hdr, size, 0)?; 350 let res = self.backend.reset_owner(); 351 self.send_ack_message(&hdr, res)?; 352 } 353 Ok(MasterReq::GET_FEATURES) => { 354 self.check_request_size(&hdr, size, 0)?; 355 let features = self.backend.get_features()?; 356 let msg = VhostUserU64::new(features); 357 self.send_reply_message(&hdr, &msg)?; 358 self.virtio_features = features; 359 self.update_reply_ack_flag(); 360 } 361 Ok(MasterReq::SET_FEATURES) => { 362 let msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?; 363 let res = self.backend.set_features(msg.value); 364 self.acked_virtio_features = msg.value; 365 self.update_reply_ack_flag(); 366 self.send_ack_message(&hdr, res)?; 367 } 368 Ok(MasterReq::SET_MEM_TABLE) => { 369 let res = self.set_mem_table(&hdr, size, &buf, files); 370 self.send_ack_message(&hdr, res)?; 371 } 372 Ok(MasterReq::SET_VRING_NUM) => { 373 let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?; 374 let res = self.backend.set_vring_num(msg.index, msg.num); 375 self.send_ack_message(&hdr, res)?; 376 } 377 Ok(MasterReq::SET_VRING_ADDR) => { 378 let msg = self.extract_request_body::<VhostUserVringAddr>(&hdr, size, &buf)?; 379 let flags = match VhostUserVringAddrFlags::from_bits(msg.flags) { 380 Some(val) => val, 381 None => return Err(Error::InvalidMessage), 382 }; 383 let res = self.backend.set_vring_addr( 384 msg.index, 385 flags, 386 msg.descriptor, 387 msg.used, 388 msg.available, 389 msg.log, 390 ); 391 self.send_ack_message(&hdr, res)?; 392 } 393 Ok(MasterReq::SET_VRING_BASE) => { 394 let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?; 395 let res = self.backend.set_vring_base(msg.index, msg.num); 396 self.send_ack_message(&hdr, res)?; 397 } 398 Ok(MasterReq::GET_VRING_BASE) => { 399 let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?; 400 let reply = self.backend.get_vring_base(msg.index)?; 401 self.send_reply_message(&hdr, &reply)?; 402 } 403 Ok(MasterReq::SET_VRING_CALL) => { 404 self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?; 405 let (index, file) = self.handle_vring_fd_request(&buf, files)?; 406 let res = self.backend.set_vring_call(index, file); 407 self.send_ack_message(&hdr, res)?; 408 } 409 Ok(MasterReq::SET_VRING_KICK) => { 410 self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?; 411 let (index, file) = self.handle_vring_fd_request(&buf, files)?; 412 let res = self.backend.set_vring_kick(index, file); 413 self.send_ack_message(&hdr, res)?; 414 } 415 Ok(MasterReq::SET_VRING_ERR) => { 416 self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?; 417 let (index, file) = self.handle_vring_fd_request(&buf, files)?; 418 let res = self.backend.set_vring_err(index, file); 419 self.send_ack_message(&hdr, res)?; 420 } 421 Ok(MasterReq::GET_PROTOCOL_FEATURES) => { 422 self.check_request_size(&hdr, size, 0)?; 423 let features = self.backend.get_protocol_features()?; 424 425 // Enable the `XEN_MMAP` protocol feature for backends if xen feature is enabled. 426 #[cfg(feature = "xen")] 427 let features = features | VhostUserProtocolFeatures::XEN_MMAP; 428 429 let msg = VhostUserU64::new(features.bits()); 430 self.send_reply_message(&hdr, &msg)?; 431 self.protocol_features = features; 432 self.update_reply_ack_flag(); 433 } 434 Ok(MasterReq::SET_PROTOCOL_FEATURES) => { 435 let msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?; 436 let res = self.backend.set_protocol_features(msg.value); 437 self.acked_protocol_features = msg.value; 438 self.update_reply_ack_flag(); 439 self.send_ack_message(&hdr, res)?; 440 441 #[cfg(feature = "xen")] 442 self.check_proto_feature(VhostUserProtocolFeatures::XEN_MMAP)?; 443 } 444 Ok(MasterReq::GET_QUEUE_NUM) => { 445 self.check_proto_feature(VhostUserProtocolFeatures::MQ)?; 446 self.check_request_size(&hdr, size, 0)?; 447 let num = self.backend.get_queue_num()?; 448 let msg = VhostUserU64::new(num); 449 self.send_reply_message(&hdr, &msg)?; 450 } 451 Ok(MasterReq::SET_VRING_ENABLE) => { 452 let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?; 453 self.check_feature(VhostUserVirtioFeatures::PROTOCOL_FEATURES)?; 454 let enable = match msg.num { 455 1 => true, 456 0 => false, 457 _ => return Err(Error::InvalidParam), 458 }; 459 460 let res = self.backend.set_vring_enable(msg.index, enable); 461 self.send_ack_message(&hdr, res)?; 462 } 463 Ok(MasterReq::GET_CONFIG) => { 464 self.check_proto_feature(VhostUserProtocolFeatures::CONFIG)?; 465 self.check_request_size(&hdr, size, hdr.get_size() as usize)?; 466 self.get_config(&hdr, &buf)?; 467 } 468 Ok(MasterReq::SET_CONFIG) => { 469 self.check_proto_feature(VhostUserProtocolFeatures::CONFIG)?; 470 self.check_request_size(&hdr, size, hdr.get_size() as usize)?; 471 let res = self.set_config(size, &buf); 472 self.send_ack_message(&hdr, res)?; 473 } 474 Ok(MasterReq::SET_SLAVE_REQ_FD) => { 475 self.check_proto_feature(VhostUserProtocolFeatures::SLAVE_REQ)?; 476 self.check_request_size(&hdr, size, hdr.get_size() as usize)?; 477 let res = self.set_slave_req_fd(files); 478 self.send_ack_message(&hdr, res)?; 479 } 480 Ok(MasterReq::GET_INFLIGHT_FD) => { 481 self.check_proto_feature(VhostUserProtocolFeatures::INFLIGHT_SHMFD)?; 482 483 let msg = self.extract_request_body::<VhostUserInflight>(&hdr, size, &buf)?; 484 let (inflight, file) = self.backend.get_inflight_fd(&msg)?; 485 let reply_hdr = self.new_reply_header::<VhostUserInflight>(&hdr, 0)?; 486 self.main_sock 487 .send_message(&reply_hdr, &inflight, Some(&[file.as_raw_fd()]))?; 488 } 489 Ok(MasterReq::SET_INFLIGHT_FD) => { 490 self.check_proto_feature(VhostUserProtocolFeatures::INFLIGHT_SHMFD)?; 491 let file = take_single_file(files).ok_or(Error::IncorrectFds)?; 492 let msg = self.extract_request_body::<VhostUserInflight>(&hdr, size, &buf)?; 493 let res = self.backend.set_inflight_fd(&msg, file); 494 self.send_ack_message(&hdr, res)?; 495 } 496 Ok(MasterReq::GET_MAX_MEM_SLOTS) => { 497 self.check_proto_feature(VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS)?; 498 self.check_request_size(&hdr, size, 0)?; 499 let num = self.backend.get_max_mem_slots()?; 500 let msg = VhostUserU64::new(num); 501 self.send_reply_message(&hdr, &msg)?; 502 } 503 Ok(MasterReq::ADD_MEM_REG) => { 504 self.check_proto_feature(VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS)?; 505 let mut files = files.ok_or(Error::InvalidParam)?; 506 if files.len() != 1 { 507 return Err(Error::InvalidParam); 508 } 509 let msg = 510 self.extract_request_body::<VhostUserSingleMemoryRegion>(&hdr, size, &buf)?; 511 let res = self.backend.add_mem_region(&msg, files.swap_remove(0)); 512 self.send_ack_message(&hdr, res)?; 513 } 514 Ok(MasterReq::REM_MEM_REG) => { 515 self.check_proto_feature(VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS)?; 516 517 let msg = 518 self.extract_request_body::<VhostUserSingleMemoryRegion>(&hdr, size, &buf)?; 519 let res = self.backend.remove_mem_region(&msg); 520 self.send_ack_message(&hdr, res)?; 521 } 522 _ => { 523 return Err(Error::InvalidMessage); 524 } 525 } 526 Ok(()) 527 } 528 set_mem_table( &mut self, hdr: &VhostUserMsgHeader<MasterReq>, size: usize, buf: &[u8], files: Option<Vec<File>>, ) -> Result<()>529 fn set_mem_table( 530 &mut self, 531 hdr: &VhostUserMsgHeader<MasterReq>, 532 size: usize, 533 buf: &[u8], 534 files: Option<Vec<File>>, 535 ) -> Result<()> { 536 self.check_request_size(hdr, size, hdr.get_size() as usize)?; 537 538 // check message size is consistent 539 let hdrsize = mem::size_of::<VhostUserMemory>(); 540 if size < hdrsize { 541 return Err(Error::InvalidMessage); 542 } 543 // SAFETY: Safe because we checked that `buf` size is at least that of 544 // VhostUserMemory. 545 let msg = unsafe { &*(buf.as_ptr() as *const VhostUserMemory) }; 546 if !msg.is_valid() { 547 return Err(Error::InvalidMessage); 548 } 549 if size != hdrsize + msg.num_regions as usize * mem::size_of::<VhostUserMemoryRegion>() { 550 return Err(Error::InvalidMessage); 551 } 552 553 // validate number of fds matching number of memory regions 554 let files = files.ok_or(Error::InvalidMessage)?; 555 if files.len() != msg.num_regions as usize { 556 return Err(Error::InvalidMessage); 557 } 558 559 // Validate memory regions 560 // 561 // SAFETY: Safe because we checked that `buf` size is equal to that of 562 // VhostUserMemory, plus `msg.num_regions` elements of VhostUserMemoryRegion. 563 let regions = unsafe { 564 slice::from_raw_parts( 565 buf.as_ptr().add(hdrsize) as *const VhostUserMemoryRegion, 566 msg.num_regions as usize, 567 ) 568 }; 569 for region in regions.iter() { 570 if !region.is_valid() { 571 return Err(Error::InvalidMessage); 572 } 573 } 574 575 self.backend.set_mem_table(regions, files) 576 } 577 get_config(&mut self, hdr: &VhostUserMsgHeader<MasterReq>, buf: &[u8]) -> Result<()>578 fn get_config(&mut self, hdr: &VhostUserMsgHeader<MasterReq>, buf: &[u8]) -> Result<()> { 579 let payload_offset = mem::size_of::<VhostUserConfig>(); 580 if buf.len() > MAX_MSG_SIZE || buf.len() < payload_offset { 581 return Err(Error::InvalidMessage); 582 } 583 // SAFETY: Safe because we checked that `buf` size is at least that of VhostUserConfig. 584 let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const VhostUserConfig) }; 585 if !msg.is_valid() { 586 return Err(Error::InvalidMessage); 587 } 588 if buf.len() - payload_offset != msg.size as usize { 589 return Err(Error::InvalidMessage); 590 } 591 let flags = match VhostUserConfigFlags::from_bits(msg.flags) { 592 Some(val) => val, 593 None => return Err(Error::InvalidMessage), 594 }; 595 let res = self.backend.get_config(msg.offset, msg.size, flags); 596 597 // vhost-user slave's payload size MUST match master's request 598 // on success, uses zero length of payload to indicate an error 599 // to vhost-user master. 600 match res { 601 Ok(ref buf) if buf.len() == msg.size as usize => { 602 let reply = VhostUserConfig::new(msg.offset, buf.len() as u32, flags); 603 self.send_reply_with_payload(hdr, &reply, buf.as_slice())?; 604 } 605 Ok(_) => { 606 let reply = VhostUserConfig::new(msg.offset, 0, flags); 607 self.send_reply_message(hdr, &reply)?; 608 } 609 Err(_) => { 610 let reply = VhostUserConfig::new(msg.offset, 0, flags); 611 self.send_reply_message(hdr, &reply)?; 612 } 613 } 614 Ok(()) 615 } 616 set_config(&mut self, size: usize, buf: &[u8]) -> Result<()>617 fn set_config(&mut self, size: usize, buf: &[u8]) -> Result<()> { 618 if size > MAX_MSG_SIZE || size < mem::size_of::<VhostUserConfig>() { 619 return Err(Error::InvalidMessage); 620 } 621 // SAFETY: Safe because we checked that `buf` size is at least that of VhostUserConfig. 622 let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const VhostUserConfig) }; 623 if !msg.is_valid() { 624 return Err(Error::InvalidMessage); 625 } 626 if size - mem::size_of::<VhostUserConfig>() != msg.size as usize { 627 return Err(Error::InvalidMessage); 628 } 629 let flags = VhostUserConfigFlags::from_bits(msg.flags).ok_or(Error::InvalidMessage)?; 630 631 self.backend 632 .set_config(msg.offset, &buf[mem::size_of::<VhostUserConfig>()..], flags) 633 } 634 set_slave_req_fd(&mut self, files: Option<Vec<File>>) -> Result<()>635 fn set_slave_req_fd(&mut self, files: Option<Vec<File>>) -> Result<()> { 636 let file = take_single_file(files).ok_or(Error::InvalidMessage)?; 637 // SAFETY: Safe because we have ownership of the files that were 638 // checked when received. We have to trust that they are Unix sockets 639 // since we have no way to check this. If not, it will fail later. 640 let sock = unsafe { UnixStream::from_raw_fd(file.into_raw_fd()) }; 641 let slave = Slave::from_stream(sock); 642 self.backend.set_slave_req_fd(slave); 643 Ok(()) 644 } 645 handle_vring_fd_request( &mut self, buf: &[u8], files: Option<Vec<File>>, ) -> Result<(u8, Option<File>)>646 fn handle_vring_fd_request( 647 &mut self, 648 buf: &[u8], 649 files: Option<Vec<File>>, 650 ) -> Result<(u8, Option<File>)> { 651 if buf.len() > MAX_MSG_SIZE || buf.len() < mem::size_of::<VhostUserU64>() { 652 return Err(Error::InvalidMessage); 653 } 654 // SAFETY: Safe because we checked that `buf` size is at least that of VhostUserU64. 655 let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const VhostUserU64) }; 656 if !msg.is_valid() { 657 return Err(Error::InvalidMessage); 658 } 659 660 // Bits (0-7) of the payload contain the vring index. Bit 8 is the 661 // invalid FD flag. This bit is set when there is no file descriptor 662 // in the ancillary data. This signals that polling will be used 663 // instead of waiting for the call. 664 // If Bit 8 is unset, the data must contain a file descriptor. 665 let has_fd = (msg.value & 0x100u64) == 0; 666 667 let file = take_single_file(files); 668 669 if has_fd && file.is_none() || !has_fd && file.is_some() { 670 return Err(Error::InvalidMessage); 671 } 672 673 Ok((msg.value as u8, file)) 674 } 675 check_state(&self) -> Result<()>676 fn check_state(&self) -> Result<()> { 677 match self.error { 678 Some(e) => Err(Error::SocketBroken(std::io::Error::from_raw_os_error(e))), 679 None => Ok(()), 680 } 681 } 682 check_request_size( &self, hdr: &VhostUserMsgHeader<MasterReq>, size: usize, expected: usize, ) -> Result<()>683 fn check_request_size( 684 &self, 685 hdr: &VhostUserMsgHeader<MasterReq>, 686 size: usize, 687 expected: usize, 688 ) -> Result<()> { 689 if hdr.get_size() as usize != expected 690 || hdr.is_reply() 691 || hdr.get_version() != 0x1 692 || size != expected 693 { 694 return Err(Error::InvalidMessage); 695 } 696 Ok(()) 697 } 698 check_attached_files( &self, hdr: &VhostUserMsgHeader<MasterReq>, files: &Option<Vec<File>>, ) -> Result<()>699 fn check_attached_files( 700 &self, 701 hdr: &VhostUserMsgHeader<MasterReq>, 702 files: &Option<Vec<File>>, 703 ) -> Result<()> { 704 match hdr.get_code() { 705 Ok( 706 MasterReq::SET_MEM_TABLE 707 | MasterReq::SET_VRING_CALL 708 | MasterReq::SET_VRING_KICK 709 | MasterReq::SET_VRING_ERR 710 | MasterReq::SET_LOG_BASE 711 | MasterReq::SET_LOG_FD 712 | MasterReq::SET_SLAVE_REQ_FD 713 | MasterReq::SET_INFLIGHT_FD 714 | MasterReq::ADD_MEM_REG, 715 ) => Ok(()), 716 _ if files.is_some() => Err(Error::InvalidMessage), 717 _ => Ok(()), 718 } 719 } 720 extract_request_body<T: Sized + VhostUserMsgValidator>( &self, hdr: &VhostUserMsgHeader<MasterReq>, size: usize, buf: &[u8], ) -> Result<T>721 fn extract_request_body<T: Sized + VhostUserMsgValidator>( 722 &self, 723 hdr: &VhostUserMsgHeader<MasterReq>, 724 size: usize, 725 buf: &[u8], 726 ) -> Result<T> { 727 self.check_request_size(hdr, size, mem::size_of::<T>())?; 728 // SAFETY: Safe because we checked that `buf` size is equal to T size. 729 let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const T) }; 730 if !msg.is_valid() { 731 return Err(Error::InvalidMessage); 732 } 733 Ok(msg) 734 } 735 update_reply_ack_flag(&mut self)736 fn update_reply_ack_flag(&mut self) { 737 let vflag = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits(); 738 let pflag = VhostUserProtocolFeatures::REPLY_ACK; 739 if (self.virtio_features & vflag) != 0 740 && self.protocol_features.contains(pflag) 741 && (self.acked_protocol_features & pflag.bits()) != 0 742 { 743 self.reply_ack_enabled = true; 744 } else { 745 self.reply_ack_enabled = false; 746 } 747 } 748 new_reply_header<T: Sized>( &self, req: &VhostUserMsgHeader<MasterReq>, payload_size: usize, ) -> Result<VhostUserMsgHeader<MasterReq>>749 fn new_reply_header<T: Sized>( 750 &self, 751 req: &VhostUserMsgHeader<MasterReq>, 752 payload_size: usize, 753 ) -> Result<VhostUserMsgHeader<MasterReq>> { 754 if mem::size_of::<T>() > MAX_MSG_SIZE 755 || payload_size > MAX_MSG_SIZE 756 || mem::size_of::<T>() + payload_size > MAX_MSG_SIZE 757 { 758 return Err(Error::InvalidParam); 759 } 760 self.check_state()?; 761 Ok(VhostUserMsgHeader::new( 762 req.get_code()?, 763 VhostUserHeaderFlag::REPLY.bits(), 764 (mem::size_of::<T>() + payload_size) as u32, 765 )) 766 } 767 send_ack_message( &mut self, req: &VhostUserMsgHeader<MasterReq>, res: Result<()>, ) -> Result<()>768 fn send_ack_message( 769 &mut self, 770 req: &VhostUserMsgHeader<MasterReq>, 771 res: Result<()>, 772 ) -> Result<()> { 773 if self.reply_ack_enabled && req.is_need_reply() { 774 let hdr = self.new_reply_header::<VhostUserU64>(req, 0)?; 775 let val = match res { 776 Ok(_) => 0, 777 Err(_) => 1, 778 }; 779 let msg = VhostUserU64::new(val); 780 self.main_sock.send_message(&hdr, &msg, None)?; 781 } 782 res 783 } 784 send_reply_message<T: ByteValued>( &mut self, req: &VhostUserMsgHeader<MasterReq>, msg: &T, ) -> Result<()>785 fn send_reply_message<T: ByteValued>( 786 &mut self, 787 req: &VhostUserMsgHeader<MasterReq>, 788 msg: &T, 789 ) -> Result<()> { 790 let hdr = self.new_reply_header::<T>(req, 0)?; 791 self.main_sock.send_message(&hdr, msg, None)?; 792 Ok(()) 793 } 794 send_reply_with_payload<T: ByteValued>( &mut self, req: &VhostUserMsgHeader<MasterReq>, msg: &T, payload: &[u8], ) -> Result<()>795 fn send_reply_with_payload<T: ByteValued>( 796 &mut self, 797 req: &VhostUserMsgHeader<MasterReq>, 798 msg: &T, 799 payload: &[u8], 800 ) -> Result<()> { 801 let hdr = self.new_reply_header::<T>(req, payload.len())?; 802 self.main_sock 803 .send_message_with_payload(&hdr, msg, payload, None)?; 804 Ok(()) 805 } 806 } 807 808 impl<S: VhostUserSlaveReqHandler> AsRawFd for SlaveReqHandler<S> { as_raw_fd(&self) -> RawFd809 fn as_raw_fd(&self) -> RawFd { 810 self.main_sock.as_raw_fd() 811 } 812 } 813 814 #[cfg(test)] 815 mod tests { 816 use std::os::unix::io::AsRawFd; 817 818 use super::*; 819 use crate::vhost_user::dummy_slave::DummySlaveReqHandler; 820 821 #[test] test_slave_req_handler_new()822 fn test_slave_req_handler_new() { 823 let (p1, _p2) = UnixStream::pair().unwrap(); 824 let endpoint = Endpoint::<MasterReq>::from_stream(p1); 825 let backend = Arc::new(Mutex::new(DummySlaveReqHandler::new())); 826 let mut handler = SlaveReqHandler::new(endpoint, backend); 827 828 handler.check_state().unwrap(); 829 handler.set_failed(libc::EAGAIN); 830 handler.check_state().unwrap_err(); 831 assert!(handler.as_raw_fd() >= 0); 832 } 833 } 834