1 // SPDX-License-Identifier: Apache-2.0 or BSD-3-Clause
2 
3 use std::{
4     collections::{HashMap, HashSet, VecDeque},
5     ops::Deref,
6     os::unix::{
7         net::UnixStream,
8         prelude::{AsRawFd, RawFd},
9     },
10     sync::{Arc, RwLock},
11 };
12 
13 use log::{info, warn};
14 use virtio_vsock::packet::{VsockPacket, PKT_HEADER_SIZE};
15 use vm_memory::bitmap::BitmapSlice;
16 
17 use crate::{
18     rxops::*,
19     vhu_vsock::{
20         CidMap, ConnMapKey, Error, Result, VSOCK_HOST_CID, VSOCK_OP_REQUEST, VSOCK_OP_RST,
21         VSOCK_TYPE_STREAM,
22     },
23     vhu_vsock_thread::VhostUserVsockThread,
24     vsock_conn::*,
25 };
26 
/// FIFO of owned vsock packets waiting to be delivered to a guest.
pub(crate) type RawPktsQ = VecDeque<RawVsockPacket>;

/// Owned copy of a vsock packet (header bytes plus payload) that no longer
/// borrows the originating guest's memory, so it can be pushed onto another
/// backend's `RawPktsQ` (sibling VM forwarding).
pub(crate) struct RawVsockPacket {
    /// Raw packet header bytes, exactly `PKT_HEADER_SIZE` long.
    pub header: [u8; PKT_HEADER_SIZE],
    /// Packet payload; empty when the packet's header length is zero.
    pub data: Vec<u8>,
}
33 
34 impl RawVsockPacket {
from_vsock_packet<B: BitmapSlice>(pkt: &VsockPacket<B>) -> Result<Self>35     fn from_vsock_packet<B: BitmapSlice>(pkt: &VsockPacket<B>) -> Result<Self> {
36         let mut raw_pkt = Self {
37             header: [0; PKT_HEADER_SIZE],
38             data: vec![0; pkt.len() as usize],
39         };
40 
41         pkt.header_slice().copy_to(&mut raw_pkt.header);
42         if !pkt.is_empty() {
43             pkt.data_slice()
44                 .ok_or(Error::PktBufMissing)?
45                 .copy_to(raw_pkt.data.as_mut());
46         }
47 
48         Ok(raw_pkt)
49     }
50 }
51 
/// Per-thread vsock backend state: guest<->host connections plus the queues
/// used to exchange packets with sibling VMs.
pub(crate) struct VsockThreadBackend {
    /// Map of ConnMapKey objects indexed by raw file descriptors
    /// (the fd of each connection's stream).
    pub listener_map: HashMap<RawFd, ConnMapKey>,
    /// Map of vsock connection objects indexed by ConnMapKey objects.
    pub conn_map: HashMap<ConnMapKey, VsockConnection<UnixStream>>,
    /// Queue of ConnMapKey objects indicating pending rx operations.
    /// A key may outlive its connection; `recv_pkt` skips such stale keys.
    pub backend_rxq: VecDeque<ConnMapKey>,
    /// Original host-side unix streams, keyed by the raw fd of the cloned
    /// stream owned by the corresponding connection. Holding them here keeps
    /// the original stream open for the connection's lifetime.
    pub stream_map: HashMap<i32, UnixStream>,
    /// Host-side unix socket path prefix. Guest-initiated connections to port
    /// `P` are forwarded to `"{host_socket_path}_{P}"`.
    host_socket_path: String,
    /// epoll for registering new host-side connections.
    epoll_fd: i32,
    /// CID of the guest.
    guest_cid: u64,
    /// Set of allocated local ports.
    pub local_port_set: HashSet<u32>,
    /// Tx buffer size handed to every new connection.
    tx_buffer_size: u32,
    /// Maps the guest CID to the corresponding backend. Used for sibling VM communication.
    pub cid_map: Arc<RwLock<CidMap>>,
    /// Queue of raw vsock packets received from sibling VMs to be sent to the guest.
    pub raw_pkts_queue: Arc<RwLock<RawPktsQ>>,
    /// Set of groups assigned to the device which it is allowed to communicate with.
    groups_set: Arc<RwLock<HashSet<String>>>,
}
77 
78 impl VsockThreadBackend {
79     /// New instance of VsockThreadBackend.
new( host_socket_path: String, epoll_fd: i32, guest_cid: u64, tx_buffer_size: u32, groups_set: Arc<RwLock<HashSet<String>>>, cid_map: Arc<RwLock<CidMap>>, ) -> Self80     pub fn new(
81         host_socket_path: String,
82         epoll_fd: i32,
83         guest_cid: u64,
84         tx_buffer_size: u32,
85         groups_set: Arc<RwLock<HashSet<String>>>,
86         cid_map: Arc<RwLock<CidMap>>,
87     ) -> Self {
88         Self {
89             listener_map: HashMap::new(),
90             conn_map: HashMap::new(),
91             backend_rxq: VecDeque::new(),
92             // Need this map to prevent connected stream from closing
93             // TODO: think of a better solution
94             stream_map: HashMap::new(),
95             host_socket_path,
96             epoll_fd,
97             guest_cid,
98             local_port_set: HashSet::new(),
99             tx_buffer_size,
100             cid_map,
101             raw_pkts_queue: Arc::new(RwLock::new(VecDeque::new())),
102             groups_set,
103         }
104     }
105 
106     /// Checks if there are pending rx requests in the backend rxq.
pending_rx(&self) -> bool107     pub fn pending_rx(&self) -> bool {
108         !self.backend_rxq.is_empty()
109     }
110 
111     /// Checks if there are pending raw vsock packets to be sent to the guest.
pending_raw_pkts(&self) -> bool112     pub fn pending_raw_pkts(&self) -> bool {
113         !self.raw_pkts_queue.read().unwrap().is_empty()
114     }
115 
116     /// Deliver a vsock packet to the guest vsock driver.
117     ///
118     /// Returns:
119     /// - `Ok(())` if the packet was successfully filled in
120     /// - `Err(Error::EmptyBackendRxQ) if there was no available data
recv_pkt<B: BitmapSlice>(&mut self, pkt: &mut VsockPacket<B>) -> Result<()>121     pub fn recv_pkt<B: BitmapSlice>(&mut self, pkt: &mut VsockPacket<B>) -> Result<()> {
122         // Pop an event from the backend_rxq
123         let key = self.backend_rxq.pop_front().ok_or(Error::EmptyBackendRxQ)?;
124         let conn = match self.conn_map.get_mut(&key) {
125             Some(conn) => conn,
126             None => {
127                 // assume that the connection does not exist
128                 return Ok(());
129             }
130         };
131 
132         if conn.rx_queue.peek() == Some(RxOps::Reset) {
133             // Handle RST events here
134             let conn = self.conn_map.remove(&key).unwrap();
135             self.listener_map.remove(&conn.stream.as_raw_fd());
136             self.stream_map.remove(&conn.stream.as_raw_fd());
137             self.local_port_set.remove(&conn.local_port);
138             VhostUserVsockThread::epoll_unregister(conn.epoll_fd, conn.stream.as_raw_fd())
139                 .unwrap_or_else(|err| {
140                     warn!(
141                         "Could not remove epoll listener for fd {:?}: {:?}",
142                         conn.stream.as_raw_fd(),
143                         err
144                     )
145                 });
146 
147             // Initialize the packet header to contain a VSOCK_OP_RST operation
148             pkt.set_op(VSOCK_OP_RST)
149                 .set_src_cid(VSOCK_HOST_CID)
150                 .set_dst_cid(conn.guest_cid)
151                 .set_src_port(conn.local_port)
152                 .set_dst_port(conn.peer_port)
153                 .set_len(0)
154                 .set_type(VSOCK_TYPE_STREAM)
155                 .set_flags(0)
156                 .set_buf_alloc(0)
157                 .set_fwd_cnt(0);
158 
159             return Ok(());
160         }
161 
162         // Handle other packet types per connection
163         conn.recv_pkt(pkt)?;
164 
165         Ok(())
166     }
167 
168     /// Deliver a guest generated packet to its destination in the backend.
169     ///
170     /// Absorbs unexpected packets, handles rest to respective connection
171     /// object.
172     ///
173     /// Returns:
174     /// - always `Ok(())` if packet has been consumed correctly
send_pkt<B: BitmapSlice>(&mut self, pkt: &VsockPacket<B>) -> Result<()>175     pub fn send_pkt<B: BitmapSlice>(&mut self, pkt: &VsockPacket<B>) -> Result<()> {
176         if pkt.src_cid() != self.guest_cid {
177             warn!(
178                 "vsock: dropping packet with inconsistent src_cid: {:?} from guest configured with CID: {:?}",
179                 pkt.src_cid(), self.guest_cid
180             );
181             return Ok(());
182         }
183 
184         let dst_cid = pkt.dst_cid();
185         if dst_cid != VSOCK_HOST_CID {
186             let cid_map = self.cid_map.read().unwrap();
187             if cid_map.contains_key(&dst_cid) {
188                 let (sibling_raw_pkts_queue, sibling_groups_set, sibling_event_fd) =
189                     cid_map.get(&dst_cid).unwrap();
190 
191                 if self
192                     .groups_set
193                     .read()
194                     .unwrap()
195                     .is_disjoint(sibling_groups_set.read().unwrap().deref())
196                 {
197                     info!(
198                         "vsock: dropping packet for cid: {:?} due to group mismatch",
199                         dst_cid
200                     );
201                     return Ok(());
202                 }
203 
204                 sibling_raw_pkts_queue
205                     .write()
206                     .unwrap()
207                     .push_back(RawVsockPacket::from_vsock_packet(pkt)?);
208                 let _ = sibling_event_fd.write(1);
209             } else {
210                 warn!("vsock: dropping packet for unknown cid: {:?}", dst_cid);
211             }
212 
213             return Ok(());
214         }
215 
216         // TODO: Rst if packet has unsupported type
217         if pkt.type_() != VSOCK_TYPE_STREAM {
218             info!("vsock: dropping packet of unknown type");
219             return Ok(());
220         }
221 
222         let key = ConnMapKey::new(pkt.dst_port(), pkt.src_port());
223 
224         // TODO: Handle cases where connection does not exist and packet op
225         // is not VSOCK_OP_REQUEST
226         if !self.conn_map.contains_key(&key) {
227             // The packet contains a new connection request
228             if pkt.op() == VSOCK_OP_REQUEST {
229                 self.handle_new_guest_conn(pkt);
230             } else {
231                 // TODO: send back RST
232             }
233             return Ok(());
234         }
235 
236         if pkt.op() == VSOCK_OP_RST {
237             // Handle an RST packet from the guest here
238             let conn = self.conn_map.get(&key).unwrap();
239             if conn.rx_queue.contains(RxOps::Reset.bitmask()) {
240                 return Ok(());
241             }
242             let conn = self.conn_map.remove(&key).unwrap();
243             self.listener_map.remove(&conn.stream.as_raw_fd());
244             self.stream_map.remove(&conn.stream.as_raw_fd());
245             self.local_port_set.remove(&conn.local_port);
246             VhostUserVsockThread::epoll_unregister(conn.epoll_fd, conn.stream.as_raw_fd())
247                 .unwrap_or_else(|err| {
248                     warn!(
249                         "Could not remove epoll listener for fd {:?}: {:?}",
250                         conn.stream.as_raw_fd(),
251                         err
252                     )
253                 });
254             return Ok(());
255         }
256 
257         // Forward this packet to its listening connection
258         let conn = self.conn_map.get_mut(&key).unwrap();
259         conn.send_pkt(pkt)?;
260 
261         if conn.rx_queue.pending_rx() {
262             // Required if the connection object adds new rx operations
263             self.backend_rxq.push_back(key);
264         }
265 
266         Ok(())
267     }
268 
269     /// Deliver a raw vsock packet sent from a sibling VM to the guest vsock driver.
270     ///
271     /// Returns:
272     /// - `Ok(())` if packet was successfully filled in
273     /// - `Err(Error::EmptyRawPktsQueue)` if there was no available data
recv_raw_pkt<B: BitmapSlice>(&mut self, pkt: &mut VsockPacket<B>) -> Result<()>274     pub fn recv_raw_pkt<B: BitmapSlice>(&mut self, pkt: &mut VsockPacket<B>) -> Result<()> {
275         let raw_vsock_pkt = self
276             .raw_pkts_queue
277             .write()
278             .unwrap()
279             .pop_front()
280             .ok_or(Error::EmptyRawPktsQueue)?;
281 
282         pkt.set_header_from_raw(&raw_vsock_pkt.header).unwrap();
283         if !raw_vsock_pkt.data.is_empty() {
284             let buf = pkt.data_slice().ok_or(Error::PktBufMissing)?;
285             buf.copy_from(&raw_vsock_pkt.data);
286         }
287 
288         Ok(())
289     }
290 
291     /// Handle a new guest initiated connection, i.e from the peer, the guest driver.
292     ///
293     /// Attempts to connect to a host side unix socket listening on a path
294     /// corresponding to the destination port as follows:
295     /// - "{self.host_sock_path}_{local_port}""
handle_new_guest_conn<B: BitmapSlice>(&mut self, pkt: &VsockPacket<B>)296     fn handle_new_guest_conn<B: BitmapSlice>(&mut self, pkt: &VsockPacket<B>) {
297         let port_path = format!("{}_{}", self.host_socket_path, pkt.dst_port());
298 
299         UnixStream::connect(port_path)
300             .and_then(|stream| stream.set_nonblocking(true).map(|_| stream))
301             .map_err(Error::UnixConnect)
302             .and_then(|stream| self.add_new_guest_conn(stream, pkt))
303             .unwrap_or_else(|_| self.enq_rst());
304     }
305 
306     /// Wrapper to add new connection to relevant HashMaps.
add_new_guest_conn<B: BitmapSlice>( &mut self, stream: UnixStream, pkt: &VsockPacket<B>, ) -> Result<()>307     fn add_new_guest_conn<B: BitmapSlice>(
308         &mut self,
309         stream: UnixStream,
310         pkt: &VsockPacket<B>,
311     ) -> Result<()> {
312         let conn = VsockConnection::new_peer_init(
313             stream.try_clone().map_err(Error::UnixConnect)?,
314             pkt.dst_cid(),
315             pkt.dst_port(),
316             pkt.src_cid(),
317             pkt.src_port(),
318             self.epoll_fd,
319             pkt.buf_alloc(),
320             self.tx_buffer_size,
321         );
322         let stream_fd = conn.stream.as_raw_fd();
323         self.listener_map
324             .insert(stream_fd, ConnMapKey::new(pkt.dst_port(), pkt.src_port()));
325 
326         self.conn_map
327             .insert(ConnMapKey::new(pkt.dst_port(), pkt.src_port()), conn);
328         self.backend_rxq
329             .push_back(ConnMapKey::new(pkt.dst_port(), pkt.src_port()));
330 
331         self.stream_map.insert(stream_fd, stream);
332         self.local_port_set.insert(pkt.dst_port());
333 
334         VhostUserVsockThread::epoll_register(
335             self.epoll_fd,
336             stream_fd,
337             epoll::Events::EPOLLIN | epoll::Events::EPOLLOUT,
338         )?;
339         Ok(())
340     }
341 
342     /// Enqueue RST packets to be sent to guest.
enq_rst(&mut self)343     fn enq_rst(&mut self) {
344         // TODO
345         dbg!("New guest conn error: Enqueue RST");
346     }
347 }
348 
#[cfg(test)]
mod tests {
    use super::*;
    use crate::vhu_vsock::{VhostUserVsockBackend, VsockConfig, VSOCK_OP_RW};
    use std::os::unix::net::UnixListener;
    use tempfile::tempdir;
    use virtio_vsock::packet::{VsockPacket, PKT_HEADER_SIZE};

    const DATA_LEN: usize = 16;
    const CONN_TX_BUF_SIZE: u32 = 64 * 1024;
    const GROUP_NAME: &str = "default";

    /// Drives a single guest<->host connection through the backend.
    /// The same packet buffer is mutated step by step, so the order of the
    /// sends below matters.
    #[test]
    fn test_vsock_thread_backend() {
        const CID: u64 = 3;
        const VSOCK_PEER_PORT: u32 = 1234;

        let test_dir = tempdir().expect("Could not create a temp test directory.");

        let vsock_socket_path = test_dir.path().join("test_vsock_thread_backend.vsock");
        // Guest-initiated connections to port P go to "{host_socket_path}_{P}",
        // so this listener accepts connections to VSOCK_PEER_PORT.
        let vsock_peer_path = test_dir.path().join("test_vsock_thread_backend.vsock_1234");

        let _listener = UnixListener::bind(&vsock_peer_path).unwrap();

        let epoll_fd = epoll::create(false).unwrap();

        let groups_set: HashSet<String> = vec![GROUP_NAME.to_string()].into_iter().collect();

        let cid_map: Arc<RwLock<CidMap>> = Arc::new(RwLock::new(HashMap::new()));

        let mut vtp = VsockThreadBackend::new(
            vsock_socket_path.display().to_string(),
            epoll_fd,
            CID,
            CONN_TX_BUF_SIZE,
            Arc::new(RwLock::new(groups_set)),
            cid_map,
        );

        assert!(!vtp.pending_rx());

        // All-zero header followed by a DATA_LEN-byte payload buffer.
        let mut pkt_raw = [0u8; PKT_HEADER_SIZE + DATA_LEN];
        let (hdr_raw, data_raw) = pkt_raw.split_at_mut(PKT_HEADER_SIZE);

        // SAFETY: Safe as hdr_raw and data_raw are guaranteed to be valid.
        let mut packet = unsafe { VsockPacket::new(hdr_raw, Some(data_raw)).unwrap() };

        // Nothing has been queued yet.
        assert_eq!(
            vtp.recv_pkt(&mut packet).unwrap_err().to_string(),
            Error::EmptyBackendRxQ.to_string()
        );

        // src_cid is still 0 (!= CID): dropped.
        assert!(vtp.send_pkt(&packet).is_ok());

        // Still dropped: src_cid has not been set yet.
        packet.set_type(VSOCK_TYPE_STREAM);
        assert!(vtp.send_pkt(&packet).is_ok());

        // Valid addressing, but op is still the zeroed initial value and no
        // connection exists, so nothing is created.
        packet.set_src_cid(CID);
        packet.set_dst_cid(VSOCK_HOST_CID);
        packet.set_dst_port(VSOCK_PEER_PORT);
        assert!(vtp.send_pkt(&packet).is_ok());

        // Connection request: connects to the listener bound above.
        packet.set_op(VSOCK_OP_REQUEST);
        assert!(vtp.send_pkt(&packet).is_ok());

        // Data on the now-existing connection.
        packet.set_op(VSOCK_OP_RW);
        assert!(vtp.send_pkt(&packet).is_ok());

        // Guest reset of the connection.
        packet.set_op(VSOCK_OP_RST);
        assert!(vtp.send_pkt(&packet).is_ok());

        // A key queued by the connection request is still in backend_rxq.
        assert!(vtp.recv_pkt(&mut packet).is_ok());

        // cleanup
        let _ = std::fs::remove_file(&vsock_peer_path);
        let _ = std::fs::remove_file(&vsock_socket_path);

        test_dir.close().unwrap();
    }

    /// Sends a packet from one guest backend to a sibling backend that shares
    /// a group ("group3"), then reads it back from the sibling's raw queue.
    #[test]
    fn test_vsock_thread_backend_sibling_vms() {
        const CID: u64 = 3;
        const SIBLING_CID: u64 = 4;
        const SIBLING_LISTENING_PORT: u32 = 1234;

        let test_dir = tempdir().expect("Could not create a temp test directory.");

        let vsock_socket_path = test_dir
            .path()
            .join("test_vsock_thread_backend.vsock")
            .display()
            .to_string();
        let sibling_vhost_socket_path = test_dir
            .path()
            .join("test_vsock_thread_backend_sibling.socket")
            .display()
            .to_string();
        let sibling_vsock_socket_path = test_dir
            .path()
            .join("test_vsock_thread_backend_sibling.vsock")
            .display()
            .to_string();

        // Shared between the sibling backend and `vtp` below.
        let cid_map: Arc<RwLock<CidMap>> = Arc::new(RwLock::new(HashMap::new()));

        let sibling_config = VsockConfig::new(
            SIBLING_CID,
            sibling_vhost_socket_path,
            sibling_vsock_socket_path,
            CONN_TX_BUF_SIZE,
            vec!["group1", "group2", "group3"]
                .into_iter()
                .map(String::from)
                .collect(),
        );

        let sibling_backend =
            Arc::new(VhostUserVsockBackend::new(sibling_config, cid_map.clone()).unwrap());

        let epoll_fd = epoll::create(false).unwrap();

        // Overlaps with the sibling's groups only in "group3".
        let groups_set: HashSet<String> = vec!["groupA", "groupB", "group3"]
            .into_iter()
            .map(String::from)
            .collect();

        let mut vtp = VsockThreadBackend::new(
            vsock_socket_path,
            epoll_fd,
            CID,
            CONN_TX_BUF_SIZE,
            Arc::new(RwLock::new(groups_set)),
            cid_map,
        );

        assert!(!vtp.pending_raw_pkts());

        let mut pkt_raw = [0u8; PKT_HEADER_SIZE + DATA_LEN];
        let (hdr_raw, data_raw) = pkt_raw.split_at_mut(PKT_HEADER_SIZE);

        // SAFETY: Safe as hdr_raw and data_raw are guaranteed to be valid.
        let mut packet = unsafe { VsockPacket::new(hdr_raw, Some(data_raw)).unwrap() };

        assert_eq!(
            vtp.recv_raw_pkt(&mut packet).unwrap_err().to_string(),
            Error::EmptyRawPktsQueue.to_string()
        );

        packet.set_type(VSOCK_TYPE_STREAM);
        packet.set_src_cid(CID);
        packet.set_dst_cid(SIBLING_CID);
        packet.set_dst_port(SIBLING_LISTENING_PORT);
        packet.set_op(VSOCK_OP_RW);
        packet.set_len(DATA_LEN as u32);
        packet
            .data_slice()
            .unwrap()
            .copy_from(&[0xCAu8, 0xFEu8, 0xBAu8, 0xBEu8]);

        // Routed to the sibling's raw packet queue rather than a host socket.
        assert!(vtp.send_pkt(&packet).is_ok());
        assert!(sibling_backend.threads[0]
            .lock()
            .unwrap()
            .thread_backend
            .pending_raw_pkts());

        let mut recvd_pkt_raw = [0u8; PKT_HEADER_SIZE + DATA_LEN];
        let (recvd_hdr_raw, recvd_data_raw) = recvd_pkt_raw.split_at_mut(PKT_HEADER_SIZE);

        let mut recvd_packet =
            // SAFETY: Safe as recvd_hdr_raw and recvd_data_raw are guaranteed to be valid.
            unsafe { VsockPacket::new(recvd_hdr_raw, Some(recvd_data_raw)).unwrap() };

        assert!(sibling_backend.threads[0]
            .lock()
            .unwrap()
            .thread_backend
            .recv_raw_pkt(&mut recvd_packet)
            .is_ok());

        // Header and payload must arrive unchanged.
        assert_eq!(recvd_packet.type_(), VSOCK_TYPE_STREAM);
        assert_eq!(recvd_packet.src_cid(), CID);
        assert_eq!(recvd_packet.dst_cid(), SIBLING_CID);
        assert_eq!(recvd_packet.dst_port(), SIBLING_LISTENING_PORT);
        assert_eq!(recvd_packet.op(), VSOCK_OP_RW);
        assert_eq!(recvd_packet.len(), DATA_LEN as u32);

        assert_eq!(recvd_data_raw[0], 0xCAu8);
        assert_eq!(recvd_data_raw[1], 0xFEu8);
        assert_eq!(recvd_data_raw[2], 0xBAu8);
        assert_eq!(recvd_data_raw[3], 0xBEu8);

        test_dir.close().unwrap();
    }
}
545