1 // SPDX-License-Identifier: Apache-2.0 or BSD-3-Clause 2 3 use std::{ 4 collections::{HashMap, HashSet, VecDeque}, 5 ops::Deref, 6 os::unix::{ 7 net::UnixStream, 8 prelude::{AsRawFd, RawFd}, 9 }, 10 sync::{Arc, RwLock}, 11 }; 12 13 use log::{info, warn}; 14 use virtio_vsock::packet::{VsockPacket, PKT_HEADER_SIZE}; 15 use vm_memory::bitmap::BitmapSlice; 16 17 use crate::{ 18 rxops::*, 19 vhu_vsock::{ 20 CidMap, ConnMapKey, Error, Result, VSOCK_HOST_CID, VSOCK_OP_REQUEST, VSOCK_OP_RST, 21 VSOCK_TYPE_STREAM, 22 }, 23 vhu_vsock_thread::VhostUserVsockThread, 24 vsock_conn::*, 25 }; 26 27 pub(crate) type RawPktsQ = VecDeque<RawVsockPacket>; 28 29 pub(crate) struct RawVsockPacket { 30 pub header: [u8; PKT_HEADER_SIZE], 31 pub data: Vec<u8>, 32 } 33 34 impl RawVsockPacket { from_vsock_packet<B: BitmapSlice>(pkt: &VsockPacket<B>) -> Result<Self>35 fn from_vsock_packet<B: BitmapSlice>(pkt: &VsockPacket<B>) -> Result<Self> { 36 let mut raw_pkt = Self { 37 header: [0; PKT_HEADER_SIZE], 38 data: vec![0; pkt.len() as usize], 39 }; 40 41 pkt.header_slice().copy_to(&mut raw_pkt.header); 42 if !pkt.is_empty() { 43 pkt.data_slice() 44 .ok_or(Error::PktBufMissing)? 45 .copy_to(raw_pkt.data.as_mut()); 46 } 47 48 Ok(raw_pkt) 49 } 50 } 51 52 pub(crate) struct VsockThreadBackend { 53 /// Map of ConnMapKey objects indexed by raw file descriptors. 54 pub listener_map: HashMap<RawFd, ConnMapKey>, 55 /// Map of vsock connection objects indexed by ConnMapKey objects. 56 pub conn_map: HashMap<ConnMapKey, VsockConnection<UnixStream>>, 57 /// Queue of ConnMapKey objects indicating pending rx operations. 58 pub backend_rxq: VecDeque<ConnMapKey>, 59 /// Map of host-side unix streams indexed by raw file descriptors. 60 pub stream_map: HashMap<i32, UnixStream>, 61 /// Host side socket for listening to new connections from the host. 62 host_socket_path: String, 63 /// epoll for registering new host-side connections. 64 epoll_fd: i32, 65 /// CID of the guest. 66 guest_cid: u64, 67 /// Set of allocated local ports. 68 pub local_port_set: HashSet<u32>, 69 tx_buffer_size: u32, 70 /// Maps the guest CID to the corresponding backend. Used for sibling VM communication. 71 pub cid_map: Arc<RwLock<CidMap>>, 72 /// Queue of raw vsock packets recieved from sibling VMs to be sent to the guest. 73 pub raw_pkts_queue: Arc<RwLock<RawPktsQ>>, 74 /// Set of groups assigned to the device which it is allowed to communicate with. 75 groups_set: Arc<RwLock<HashSet<String>>>, 76 } 77 78 impl VsockThreadBackend { 79 /// New instance of VsockThreadBackend. new( host_socket_path: String, epoll_fd: i32, guest_cid: u64, tx_buffer_size: u32, groups_set: Arc<RwLock<HashSet<String>>>, cid_map: Arc<RwLock<CidMap>>, ) -> Self80 pub fn new( 81 host_socket_path: String, 82 epoll_fd: i32, 83 guest_cid: u64, 84 tx_buffer_size: u32, 85 groups_set: Arc<RwLock<HashSet<String>>>, 86 cid_map: Arc<RwLock<CidMap>>, 87 ) -> Self { 88 Self { 89 listener_map: HashMap::new(), 90 conn_map: HashMap::new(), 91 backend_rxq: VecDeque::new(), 92 // Need this map to prevent connected stream from closing 93 // TODO: think of a better solution 94 stream_map: HashMap::new(), 95 host_socket_path, 96 epoll_fd, 97 guest_cid, 98 local_port_set: HashSet::new(), 99 tx_buffer_size, 100 cid_map, 101 raw_pkts_queue: Arc::new(RwLock::new(VecDeque::new())), 102 groups_set, 103 } 104 } 105 106 /// Checks if there are pending rx requests in the backend rxq. pending_rx(&self) -> bool107 pub fn pending_rx(&self) -> bool { 108 !self.backend_rxq.is_empty() 109 } 110 111 /// Checks if there are pending raw vsock packets to be sent to the guest. pending_raw_pkts(&self) -> bool112 pub fn pending_raw_pkts(&self) -> bool { 113 !self.raw_pkts_queue.read().unwrap().is_empty() 114 } 115 116 /// Deliver a vsock packet to the guest vsock driver. 117 /// 118 /// Returns: 119 /// - `Ok(())` if the packet was successfully filled in 120 /// - `Err(Error::EmptyBackendRxQ) if there was no available data recv_pkt<B: BitmapSlice>(&mut self, pkt: &mut VsockPacket<B>) -> Result<()>121 pub fn recv_pkt<B: BitmapSlice>(&mut self, pkt: &mut VsockPacket<B>) -> Result<()> { 122 // Pop an event from the backend_rxq 123 let key = self.backend_rxq.pop_front().ok_or(Error::EmptyBackendRxQ)?; 124 let conn = match self.conn_map.get_mut(&key) { 125 Some(conn) => conn, 126 None => { 127 // assume that the connection does not exist 128 return Ok(()); 129 } 130 }; 131 132 if conn.rx_queue.peek() == Some(RxOps::Reset) { 133 // Handle RST events here 134 let conn = self.conn_map.remove(&key).unwrap(); 135 self.listener_map.remove(&conn.stream.as_raw_fd()); 136 self.stream_map.remove(&conn.stream.as_raw_fd()); 137 self.local_port_set.remove(&conn.local_port); 138 VhostUserVsockThread::epoll_unregister(conn.epoll_fd, conn.stream.as_raw_fd()) 139 .unwrap_or_else(|err| { 140 warn!( 141 "Could not remove epoll listener for fd {:?}: {:?}", 142 conn.stream.as_raw_fd(), 143 err 144 ) 145 }); 146 147 // Initialize the packet header to contain a VSOCK_OP_RST operation 148 pkt.set_op(VSOCK_OP_RST) 149 .set_src_cid(VSOCK_HOST_CID) 150 .set_dst_cid(conn.guest_cid) 151 .set_src_port(conn.local_port) 152 .set_dst_port(conn.peer_port) 153 .set_len(0) 154 .set_type(VSOCK_TYPE_STREAM) 155 .set_flags(0) 156 .set_buf_alloc(0) 157 .set_fwd_cnt(0); 158 159 return Ok(()); 160 } 161 162 // Handle other packet types per connection 163 conn.recv_pkt(pkt)?; 164 165 Ok(()) 166 } 167 168 /// Deliver a guest generated packet to its destination in the backend. 169 /// 170 /// Absorbs unexpected packets, handles rest to respective connection 171 /// object. 172 /// 173 /// Returns: 174 /// - always `Ok(())` if packet has been consumed correctly send_pkt<B: BitmapSlice>(&mut self, pkt: &VsockPacket<B>) -> Result<()>175 pub fn send_pkt<B: BitmapSlice>(&mut self, pkt: &VsockPacket<B>) -> Result<()> { 176 if pkt.src_cid() != self.guest_cid { 177 warn!( 178 "vsock: dropping packet with inconsistent src_cid: {:?} from guest configured with CID: {:?}", 179 pkt.src_cid(), self.guest_cid 180 ); 181 return Ok(()); 182 } 183 184 let dst_cid = pkt.dst_cid(); 185 if dst_cid != VSOCK_HOST_CID { 186 let cid_map = self.cid_map.read().unwrap(); 187 if cid_map.contains_key(&dst_cid) { 188 let (sibling_raw_pkts_queue, sibling_groups_set, sibling_event_fd) = 189 cid_map.get(&dst_cid).unwrap(); 190 191 if self 192 .groups_set 193 .read() 194 .unwrap() 195 .is_disjoint(sibling_groups_set.read().unwrap().deref()) 196 { 197 info!( 198 "vsock: dropping packet for cid: {:?} due to group mismatch", 199 dst_cid 200 ); 201 return Ok(()); 202 } 203 204 sibling_raw_pkts_queue 205 .write() 206 .unwrap() 207 .push_back(RawVsockPacket::from_vsock_packet(pkt)?); 208 let _ = sibling_event_fd.write(1); 209 } else { 210 warn!("vsock: dropping packet for unknown cid: {:?}", dst_cid); 211 } 212 213 return Ok(()); 214 } 215 216 // TODO: Rst if packet has unsupported type 217 if pkt.type_() != VSOCK_TYPE_STREAM { 218 info!("vsock: dropping packet of unknown type"); 219 return Ok(()); 220 } 221 222 let key = ConnMapKey::new(pkt.dst_port(), pkt.src_port()); 223 224 // TODO: Handle cases where connection does not exist and packet op 225 // is not VSOCK_OP_REQUEST 226 if !self.conn_map.contains_key(&key) { 227 // The packet contains a new connection request 228 if pkt.op() == VSOCK_OP_REQUEST { 229 self.handle_new_guest_conn(pkt); 230 } else { 231 // TODO: send back RST 232 } 233 return Ok(()); 234 } 235 236 if pkt.op() == VSOCK_OP_RST { 237 // Handle an RST packet from the guest here 238 let conn = self.conn_map.get(&key).unwrap(); 239 if conn.rx_queue.contains(RxOps::Reset.bitmask()) { 240 return Ok(()); 241 } 242 let conn = self.conn_map.remove(&key).unwrap(); 243 self.listener_map.remove(&conn.stream.as_raw_fd()); 244 self.stream_map.remove(&conn.stream.as_raw_fd()); 245 self.local_port_set.remove(&conn.local_port); 246 VhostUserVsockThread::epoll_unregister(conn.epoll_fd, conn.stream.as_raw_fd()) 247 .unwrap_or_else(|err| { 248 warn!( 249 "Could not remove epoll listener for fd {:?}: {:?}", 250 conn.stream.as_raw_fd(), 251 err 252 ) 253 }); 254 return Ok(()); 255 } 256 257 // Forward this packet to its listening connection 258 let conn = self.conn_map.get_mut(&key).unwrap(); 259 conn.send_pkt(pkt)?; 260 261 if conn.rx_queue.pending_rx() { 262 // Required if the connection object adds new rx operations 263 self.backend_rxq.push_back(key); 264 } 265 266 Ok(()) 267 } 268 269 /// Deliver a raw vsock packet sent from a sibling VM to the guest vsock driver. 270 /// 271 /// Returns: 272 /// - `Ok(())` if packet was successfully filled in 273 /// - `Err(Error::EmptyRawPktsQueue)` if there was no available data recv_raw_pkt<B: BitmapSlice>(&mut self, pkt: &mut VsockPacket<B>) -> Result<()>274 pub fn recv_raw_pkt<B: BitmapSlice>(&mut self, pkt: &mut VsockPacket<B>) -> Result<()> { 275 let raw_vsock_pkt = self 276 .raw_pkts_queue 277 .write() 278 .unwrap() 279 .pop_front() 280 .ok_or(Error::EmptyRawPktsQueue)?; 281 282 pkt.set_header_from_raw(&raw_vsock_pkt.header).unwrap(); 283 if !raw_vsock_pkt.data.is_empty() { 284 let buf = pkt.data_slice().ok_or(Error::PktBufMissing)?; 285 buf.copy_from(&raw_vsock_pkt.data); 286 } 287 288 Ok(()) 289 } 290 291 /// Handle a new guest initiated connection, i.e from the peer, the guest driver. 292 /// 293 /// Attempts to connect to a host side unix socket listening on a path 294 /// corresponding to the destination port as follows: 295 /// - "{self.host_sock_path}_{local_port}"" handle_new_guest_conn<B: BitmapSlice>(&mut self, pkt: &VsockPacket<B>)296 fn handle_new_guest_conn<B: BitmapSlice>(&mut self, pkt: &VsockPacket<B>) { 297 let port_path = format!("{}_{}", self.host_socket_path, pkt.dst_port()); 298 299 UnixStream::connect(port_path) 300 .and_then(|stream| stream.set_nonblocking(true).map(|_| stream)) 301 .map_err(Error::UnixConnect) 302 .and_then(|stream| self.add_new_guest_conn(stream, pkt)) 303 .unwrap_or_else(|_| self.enq_rst()); 304 } 305 306 /// Wrapper to add new connection to relevant HashMaps. add_new_guest_conn<B: BitmapSlice>( &mut self, stream: UnixStream, pkt: &VsockPacket<B>, ) -> Result<()>307 fn add_new_guest_conn<B: BitmapSlice>( 308 &mut self, 309 stream: UnixStream, 310 pkt: &VsockPacket<B>, 311 ) -> Result<()> { 312 let conn = VsockConnection::new_peer_init( 313 stream.try_clone().map_err(Error::UnixConnect)?, 314 pkt.dst_cid(), 315 pkt.dst_port(), 316 pkt.src_cid(), 317 pkt.src_port(), 318 self.epoll_fd, 319 pkt.buf_alloc(), 320 self.tx_buffer_size, 321 ); 322 let stream_fd = conn.stream.as_raw_fd(); 323 self.listener_map 324 .insert(stream_fd, ConnMapKey::new(pkt.dst_port(), pkt.src_port())); 325 326 self.conn_map 327 .insert(ConnMapKey::new(pkt.dst_port(), pkt.src_port()), conn); 328 self.backend_rxq 329 .push_back(ConnMapKey::new(pkt.dst_port(), pkt.src_port())); 330 331 self.stream_map.insert(stream_fd, stream); 332 self.local_port_set.insert(pkt.dst_port()); 333 334 VhostUserVsockThread::epoll_register( 335 self.epoll_fd, 336 stream_fd, 337 epoll::Events::EPOLLIN | epoll::Events::EPOLLOUT, 338 )?; 339 Ok(()) 340 } 341 342 /// Enqueue RST packets to be sent to guest. enq_rst(&mut self)343 fn enq_rst(&mut self) { 344 // TODO 345 dbg!("New guest conn error: Enqueue RST"); 346 } 347 } 348 349 #[cfg(test)] 350 mod tests { 351 use super::*; 352 use crate::vhu_vsock::{VhostUserVsockBackend, VsockConfig, VSOCK_OP_RW}; 353 use std::os::unix::net::UnixListener; 354 use tempfile::tempdir; 355 use virtio_vsock::packet::{VsockPacket, PKT_HEADER_SIZE}; 356 357 const DATA_LEN: usize = 16; 358 const CONN_TX_BUF_SIZE: u32 = 64 * 1024; 359 const GROUP_NAME: &str = "default"; 360 361 #[test] test_vsock_thread_backend()362 fn test_vsock_thread_backend() { 363 const CID: u64 = 3; 364 const VSOCK_PEER_PORT: u32 = 1234; 365 366 let test_dir = tempdir().expect("Could not create a temp test directory."); 367 368 let vsock_socket_path = test_dir.path().join("test_vsock_thread_backend.vsock"); 369 let vsock_peer_path = test_dir.path().join("test_vsock_thread_backend.vsock_1234"); 370 371 let _listener = UnixListener::bind(&vsock_peer_path).unwrap(); 372 373 let epoll_fd = epoll::create(false).unwrap(); 374 375 let groups_set: HashSet<String> = vec![GROUP_NAME.to_string()].into_iter().collect(); 376 377 let cid_map: Arc<RwLock<CidMap>> = Arc::new(RwLock::new(HashMap::new())); 378 379 let mut vtp = VsockThreadBackend::new( 380 vsock_socket_path.display().to_string(), 381 epoll_fd, 382 CID, 383 CONN_TX_BUF_SIZE, 384 Arc::new(RwLock::new(groups_set)), 385 cid_map, 386 ); 387 388 assert!(!vtp.pending_rx()); 389 390 let mut pkt_raw = [0u8; PKT_HEADER_SIZE + DATA_LEN]; 391 let (hdr_raw, data_raw) = pkt_raw.split_at_mut(PKT_HEADER_SIZE); 392 393 // SAFETY: Safe as hdr_raw and data_raw are guaranteed to be valid. 394 let mut packet = unsafe { VsockPacket::new(hdr_raw, Some(data_raw)).unwrap() }; 395 396 assert_eq!( 397 vtp.recv_pkt(&mut packet).unwrap_err().to_string(), 398 Error::EmptyBackendRxQ.to_string() 399 ); 400 401 assert!(vtp.send_pkt(&packet).is_ok()); 402 403 packet.set_type(VSOCK_TYPE_STREAM); 404 assert!(vtp.send_pkt(&packet).is_ok()); 405 406 packet.set_src_cid(CID); 407 packet.set_dst_cid(VSOCK_HOST_CID); 408 packet.set_dst_port(VSOCK_PEER_PORT); 409 assert!(vtp.send_pkt(&packet).is_ok()); 410 411 packet.set_op(VSOCK_OP_REQUEST); 412 assert!(vtp.send_pkt(&packet).is_ok()); 413 414 packet.set_op(VSOCK_OP_RW); 415 assert!(vtp.send_pkt(&packet).is_ok()); 416 417 packet.set_op(VSOCK_OP_RST); 418 assert!(vtp.send_pkt(&packet).is_ok()); 419 420 assert!(vtp.recv_pkt(&mut packet).is_ok()); 421 422 // cleanup 423 let _ = std::fs::remove_file(&vsock_peer_path); 424 let _ = std::fs::remove_file(&vsock_socket_path); 425 426 test_dir.close().unwrap(); 427 } 428 429 #[test] test_vsock_thread_backend_sibling_vms()430 fn test_vsock_thread_backend_sibling_vms() { 431 const CID: u64 = 3; 432 const SIBLING_CID: u64 = 4; 433 const SIBLING_LISTENING_PORT: u32 = 1234; 434 435 let test_dir = tempdir().expect("Could not create a temp test directory."); 436 437 let vsock_socket_path = test_dir 438 .path() 439 .join("test_vsock_thread_backend.vsock") 440 .display() 441 .to_string(); 442 let sibling_vhost_socket_path = test_dir 443 .path() 444 .join("test_vsock_thread_backend_sibling.socket") 445 .display() 446 .to_string(); 447 let sibling_vsock_socket_path = test_dir 448 .path() 449 .join("test_vsock_thread_backend_sibling.vsock") 450 .display() 451 .to_string(); 452 453 let cid_map: Arc<RwLock<CidMap>> = Arc::new(RwLock::new(HashMap::new())); 454 455 let sibling_config = VsockConfig::new( 456 SIBLING_CID, 457 sibling_vhost_socket_path, 458 sibling_vsock_socket_path, 459 CONN_TX_BUF_SIZE, 460 vec!["group1", "group2", "group3"] 461 .into_iter() 462 .map(String::from) 463 .collect(), 464 ); 465 466 let sibling_backend = 467 Arc::new(VhostUserVsockBackend::new(sibling_config, cid_map.clone()).unwrap()); 468 469 let epoll_fd = epoll::create(false).unwrap(); 470 471 let groups_set: HashSet<String> = vec!["groupA", "groupB", "group3"] 472 .into_iter() 473 .map(String::from) 474 .collect(); 475 476 let mut vtp = VsockThreadBackend::new( 477 vsock_socket_path, 478 epoll_fd, 479 CID, 480 CONN_TX_BUF_SIZE, 481 Arc::new(RwLock::new(groups_set)), 482 cid_map, 483 ); 484 485 assert!(!vtp.pending_raw_pkts()); 486 487 let mut pkt_raw = [0u8; PKT_HEADER_SIZE + DATA_LEN]; 488 let (hdr_raw, data_raw) = pkt_raw.split_at_mut(PKT_HEADER_SIZE); 489 490 // SAFETY: Safe as hdr_raw and data_raw are guaranteed to be valid. 491 let mut packet = unsafe { VsockPacket::new(hdr_raw, Some(data_raw)).unwrap() }; 492 493 assert_eq!( 494 vtp.recv_raw_pkt(&mut packet).unwrap_err().to_string(), 495 Error::EmptyRawPktsQueue.to_string() 496 ); 497 498 packet.set_type(VSOCK_TYPE_STREAM); 499 packet.set_src_cid(CID); 500 packet.set_dst_cid(SIBLING_CID); 501 packet.set_dst_port(SIBLING_LISTENING_PORT); 502 packet.set_op(VSOCK_OP_RW); 503 packet.set_len(DATA_LEN as u32); 504 packet 505 .data_slice() 506 .unwrap() 507 .copy_from(&[0xCAu8, 0xFEu8, 0xBAu8, 0xBEu8]); 508 509 assert!(vtp.send_pkt(&packet).is_ok()); 510 assert!(sibling_backend.threads[0] 511 .lock() 512 .unwrap() 513 .thread_backend 514 .pending_raw_pkts()); 515 516 let mut recvd_pkt_raw = [0u8; PKT_HEADER_SIZE + DATA_LEN]; 517 let (recvd_hdr_raw, recvd_data_raw) = recvd_pkt_raw.split_at_mut(PKT_HEADER_SIZE); 518 519 let mut recvd_packet = 520 // SAFETY: Safe as recvd_hdr_raw and recvd_data_raw are guaranteed to be valid. 521 unsafe { VsockPacket::new(recvd_hdr_raw, Some(recvd_data_raw)).unwrap() }; 522 523 assert!(sibling_backend.threads[0] 524 .lock() 525 .unwrap() 526 .thread_backend 527 .recv_raw_pkt(&mut recvd_packet) 528 .is_ok()); 529 530 assert_eq!(recvd_packet.type_(), VSOCK_TYPE_STREAM); 531 assert_eq!(recvd_packet.src_cid(), CID); 532 assert_eq!(recvd_packet.dst_cid(), SIBLING_CID); 533 assert_eq!(recvd_packet.dst_port(), SIBLING_LISTENING_PORT); 534 assert_eq!(recvd_packet.op(), VSOCK_OP_RW); 535 assert_eq!(recvd_packet.len(), DATA_LEN as u32); 536 537 assert_eq!(recvd_data_raw[0], 0xCAu8); 538 assert_eq!(recvd_data_raw[1], 0xFEu8); 539 assert_eq!(recvd_data_raw[2], 0xBAu8); 540 assert_eq!(recvd_data_raw[3], 0xBEu8); 541 542 test_dir.close().unwrap(); 543 } 544 } 545