1 // SPDX-License-Identifier: Apache-2.0 or BSD-3-Clause
2 
3 use std::{
4     collections::HashSet,
5     fs::File,
6     io,
7     io::Read,
8     iter::FromIterator,
9     num::Wrapping,
10     ops::Deref,
11     os::unix::{
12         net::{UnixListener, UnixStream},
13         prelude::{AsRawFd, FromRawFd, RawFd},
14     },
15     sync::mpsc::Sender,
16     sync::{mpsc, Arc, RwLock},
17     thread,
18 };
19 
20 use log::warn;
21 use vhost_user_backend::{VringEpollHandler, VringRwLock, VringT};
22 use virtio_queue::QueueOwnedT;
23 use virtio_vsock::packet::{VsockPacket, PKT_HEADER_SIZE};
24 use vm_memory::{GuestAddressSpace, GuestMemoryAtomic, GuestMemoryMmap};
25 use vmm_sys_util::{
26     epoll::EventSet,
27     eventfd::{EventFd, EFD_NONBLOCK},
28 };
29 
30 use crate::{
31     rxops::*,
32     thread_backend::*,
33     vhu_vsock::{
34         CidMap, ConnMapKey, Error, Result, VhostUserVsockBackend, BACKEND_EVENT, SIBLING_VM_EVENT,
35         VSOCK_HOST_CID,
36     },
37     vsock_conn::*,
38 };
39 
40 type ArcVhostBknd = Arc<VhostUserVsockBackend>;
41 
/// Which RX queue a processing pass should drain.
enum RxQueueType {
    /// Packets for host <-> guest connections (backend rx queue).
    Standard,
    /// Raw vsock packets sent from a sibling VM.
    RawPkts,
}
46 
// Data which is required by a worker handling event idx.
struct EventData {
    /// Vring whose used ring receives the completed descriptor chain.
    vring: VringRwLock,
    /// Whether VIRTIO_RING_F_EVENT_IDX was negotiated (controls notification suppression).
    event_idx: bool,
    /// Head index of the descriptor chain to return to the used ring.
    head_idx: u16,
    /// Number of bytes written into the descriptor chain.
    used_len: usize,
}
54 
/// Per-guest worker state: owns the host-side Unix listener, the epoll set for
/// host connections, and the backend holding all connection state for one CID.
pub(crate) struct VhostUserVsockThread {
    /// Guest memory map.
    pub mem: Option<GuestMemoryAtomic<GuestMemoryMmap>>,
    /// VIRTIO_RING_F_EVENT_IDX.
    pub event_idx: bool,
    /// Host socket raw file descriptor.
    host_sock: RawFd,
    /// Host socket path; the file is removed again in `Drop`.
    host_sock_path: String,
    /// Listener listening for new connections on the host.
    host_listener: UnixListener,
    /// epoll fd to which new host connections are added.
    epoll_file: File,
    /// VsockThreadBackend instance.
    pub thread_backend: VsockThreadBackend,
    /// CID of the guest; deregistered from the cid map in `Drop`.
    guest_cid: u64,
    /// Channel to a worker which handles event idx.
    sender: Sender<EventData>,
    /// host side port on which application listens.
    local_port: Wrapping<u32>,
    /// The tx buffer size
    tx_buffer_size: u32,
    /// EventFd to notify this thread for custom events. Currently used to notify
    /// this thread to process raw vsock packets sent from a sibling VM.
    pub sibling_event_fd: EventFd,
    /// Keeps track of which RX queue was processed first in the last iteration.
    /// Used to alternate between the RX queues to prevent the starvation of one by the other.
    last_processed: RxQueueType,
}
85 
86 impl VhostUserVsockThread {
    /// Create a new instance of VhostUserVsockThread.
    ///
    /// Binds a nonblocking Unix listener at `uds_path` (removing any stale
    /// socket file first), creates the epoll set used for host connections,
    /// reserves `guest_cid` in `cid_map`, and spawns a worker thread that
    /// returns used descriptors to vrings via an mpsc channel.
    ///
    /// Returns `Error::CidAlreadyInUse` if `guest_cid` is already registered.
    pub fn new(
        uds_path: String,
        guest_cid: u64,
        tx_buffer_size: u32,
        groups: Vec<String>,
        cid_map: Arc<RwLock<CidMap>>,
    ) -> Result<Self> {
        // TODO: better error handling, maybe add a param to force the unlink
        let _ = std::fs::remove_file(uds_path.clone());
        let host_sock = UnixListener::bind(&uds_path)
            .and_then(|sock| sock.set_nonblocking(true).map(|_| sock))
            .map_err(Error::UnixBind)?;

        let epoll_fd = epoll::create(true).map_err(Error::EpollFdCreate)?;
        // SAFETY: Safe as the fd is guaranteed to be valid here.
        let epoll_file = unsafe { File::from_raw_fd(epoll_fd) };

        let host_raw_fd = host_sock.as_raw_fd();

        let mut groups = groups;
        let groups_set: Arc<RwLock<HashSet<String>>> =
            Arc::new(RwLock::new(HashSet::from_iter(groups.drain(..))));

        let sibling_event_fd = EventFd::new(EFD_NONBLOCK).map_err(Error::EventFdCreate)?;

        let thread_backend = VsockThreadBackend::new(
            uds_path.clone(),
            epoll_fd,
            guest_cid,
            tx_buffer_size,
            groups_set.clone(),
            cid_map.clone(),
        );

        // Reserve the CID under the write lock; the scope limits how long the
        // map stays locked.
        {
            let mut cid_map = cid_map.write().unwrap();
            if cid_map.contains_key(&guest_cid) {
                return Err(Error::CidAlreadyInUse);
            }

            cid_map.insert(
                guest_cid,
                (
                    thread_backend.raw_pkts_queue.clone(),
                    groups_set,
                    sibling_event_fd.try_clone().unwrap(),
                ),
            );
        }
        // Worker thread: drains EventData messages and performs add_used +
        // guest notification, until the sender side is dropped.
        let (sender, receiver) = mpsc::channel::<EventData>();
        thread::spawn(move || loop {
            // TODO: Understand why doing the following in the background thread works.
            // maybe we'd better have thread pool for the entire application if necessary.
            let Ok(event_data) = receiver.recv() else {
                break;
            };
            Self::vring_handle_event(event_data);
        });
        let thread = VhostUserVsockThread {
            mem: None,
            event_idx: false,
            host_sock: host_sock.as_raw_fd(),
            host_sock_path: uds_path,
            host_listener: host_sock,
            epoll_file,
            thread_backend,
            guest_cid,
            sender,
            local_port: Wrapping(0),
            tx_buffer_size,
            sibling_event_fd,
            last_processed: RxQueueType::Standard,
        };

        // Start watching the host listener for incoming connections.
        VhostUserVsockThread::epoll_register(epoll_fd, host_raw_fd, epoll::Events::EPOLLIN)?;

        Ok(thread)
    }
166 
vring_handle_event(event_data: EventData)167     fn vring_handle_event(event_data: EventData) {
168         if event_data.event_idx {
169             if event_data
170                 .vring
171                 .add_used(event_data.head_idx, event_data.used_len as u32)
172                 .is_err()
173             {
174                 warn!("Could not return used descriptors to ring");
175             }
176             match event_data.vring.needs_notification() {
177                 Err(_) => {
178                     warn!("Could not check if queue needs to be notified");
179                     event_data.vring.signal_used_queue().unwrap();
180                 }
181                 Ok(needs_notification) => {
182                     if needs_notification {
183                         event_data.vring.signal_used_queue().unwrap();
184                     }
185                 }
186             }
187         } else {
188             if event_data
189                 .vring
190                 .add_used(event_data.head_idx, event_data.used_len as u32)
191                 .is_err()
192             {
193                 warn!("Could not return used descriptors to ring");
194             }
195             event_data.vring.signal_used_queue().unwrap();
196         }
197     }
198     /// Register a file with an epoll to listen for events in evset.
epoll_register(epoll_fd: RawFd, fd: RawFd, evset: epoll::Events) -> Result<()>199     pub fn epoll_register(epoll_fd: RawFd, fd: RawFd, evset: epoll::Events) -> Result<()> {
200         epoll::ctl(
201             epoll_fd,
202             epoll::ControlOptions::EPOLL_CTL_ADD,
203             fd,
204             epoll::Event::new(evset, fd as u64),
205         )
206         .map_err(Error::EpollAdd)?;
207 
208         Ok(())
209     }
210 
211     /// Remove a file from the epoll.
epoll_unregister(epoll_fd: RawFd, fd: RawFd) -> Result<()>212     pub fn epoll_unregister(epoll_fd: RawFd, fd: RawFd) -> Result<()> {
213         epoll::ctl(
214             epoll_fd,
215             epoll::ControlOptions::EPOLL_CTL_DEL,
216             fd,
217             epoll::Event::new(epoll::Events::empty(), 0),
218         )
219         .map_err(Error::EpollRemove)?;
220 
221         Ok(())
222     }
223 
224     /// Modify the events we listen to for the fd in the epoll.
epoll_modify(epoll_fd: RawFd, fd: RawFd, evset: epoll::Events) -> Result<()>225     pub fn epoll_modify(epoll_fd: RawFd, fd: RawFd, evset: epoll::Events) -> Result<()> {
226         epoll::ctl(
227             epoll_fd,
228             epoll::ControlOptions::EPOLL_CTL_MOD,
229             fd,
230             epoll::Event::new(evset, fd as u64),
231         )
232         .map_err(Error::EpollModify)?;
233 
234         Ok(())
235     }
236 
    /// Return raw file descriptor of the epoll file.
    ///
    /// The fd stays owned by `self.epoll_file`; callers must not close it.
    fn get_epoll_fd(&self) -> RawFd {
        self.epoll_file.as_raw_fd()
    }
241 
242     /// Register our listeners in the VringEpollHandler
register_listeners( &mut self, epoll_handler: Arc<VringEpollHandler<ArcVhostBknd, VringRwLock, ()>>, )243     pub fn register_listeners(
244         &mut self,
245         epoll_handler: Arc<VringEpollHandler<ArcVhostBknd, VringRwLock, ()>>,
246     ) {
247         epoll_handler
248             .register_listener(self.get_epoll_fd(), EventSet::IN, u64::from(BACKEND_EVENT))
249             .unwrap();
250         epoll_handler
251             .register_listener(
252                 self.sibling_event_fd.as_raw_fd(),
253                 EventSet::IN,
254                 u64::from(SIBLING_VM_EVENT),
255             )
256             .unwrap();
257     }
258 
259     /// Process a BACKEND_EVENT received by VhostUserVsockBackend.
process_backend_evt(&mut self, _evset: EventSet)260     pub fn process_backend_evt(&mut self, _evset: EventSet) {
261         let mut epoll_events = vec![epoll::Event::new(epoll::Events::empty(), 0); 32];
262         'epoll: loop {
263             match epoll::wait(self.epoll_file.as_raw_fd(), 0, epoll_events.as_mut_slice()) {
264                 Ok(ev_cnt) => {
265                     for evt in epoll_events.iter().take(ev_cnt) {
266                         self.handle_event(
267                             evt.data as RawFd,
268                             epoll::Events::from_bits(evt.events).unwrap(),
269                         );
270                     }
271                 }
272                 Err(e) => {
273                     if e.kind() == io::ErrorKind::Interrupted {
274                         continue;
275                     }
276                     warn!("failed to consume new epoll event");
277                 }
278             }
279             break 'epoll;
280         }
281     }
282 
283     /// Handle a BACKEND_EVENT by either accepting a new connection or
284     /// forwarding a request to the appropriate connection object.
handle_event(&mut self, fd: RawFd, evset: epoll::Events)285     fn handle_event(&mut self, fd: RawFd, evset: epoll::Events) {
286         if fd == self.host_sock {
287             // This is a new connection initiated by an application running on the host
288             let conn = self.host_listener.accept().map_err(Error::UnixAccept);
289             if self.mem.is_some() {
290                 conn.and_then(|(stream, _)| {
291                     stream
292                         .set_nonblocking(true)
293                         .map(|_| stream)
294                         .map_err(Error::UnixAccept)
295                 })
296                 .and_then(|stream| self.add_stream_listener(stream))
297                 .unwrap_or_else(|err| {
298                     warn!("Unable to accept new local connection: {:?}", err);
299                 });
300             } else {
301                 // If we aren't ready to process requests, accept and immediately close
302                 // the connection.
303                 conn.map(drop).unwrap_or_else(|err| {
304                     warn!("Error closing an incoming connection: {:?}", err);
305                 });
306             }
307         } else {
308             // Check if the stream represented by fd has already established a
309             // connection with the application running in the guest
310             if let std::collections::hash_map::Entry::Vacant(_) =
311                 self.thread_backend.listener_map.entry(fd)
312             {
313                 // New connection from the host
314                 if evset.bits() != epoll::Events::EPOLLIN.bits() {
315                     // Has to be EPOLLIN as it was not connected previously
316                     return;
317                 }
318                 let mut unix_stream = match self.thread_backend.stream_map.remove(&fd) {
319                     Some(uds) => uds,
320                     None => {
321                         warn!("Error while searching fd in the stream map");
322                         return;
323                     }
324                 };
325 
326                 // Local peer is sending a "connect PORT\n" command
327                 let peer_port = match Self::read_local_stream_port(&mut unix_stream) {
328                     Ok(port) => port,
329                     Err(err) => {
330                         warn!("Error while parsing \"connect PORT\n\" command: {:?}", err);
331                         return;
332                     }
333                 };
334 
335                 // Allocate a local port number
336                 let local_port = match self.allocate_local_port() {
337                     Ok(lp) => lp,
338                     Err(err) => {
339                         warn!("Error while allocating local port: {:?}", err);
340                         return;
341                     }
342                 };
343 
344                 // Insert the fd into the backend's maps
345                 self.thread_backend
346                     .listener_map
347                     .insert(fd, ConnMapKey::new(local_port, peer_port));
348 
349                 // Create a new connection object an enqueue a connection request
350                 // packet to be sent to the guest
351                 let conn_map_key = ConnMapKey::new(local_port, peer_port);
352                 let mut new_conn = VsockConnection::new_local_init(
353                     unix_stream,
354                     VSOCK_HOST_CID,
355                     local_port,
356                     self.guest_cid,
357                     peer_port,
358                     self.get_epoll_fd(),
359                     self.tx_buffer_size,
360                 );
361                 new_conn.rx_queue.enqueue(RxOps::Request);
362                 new_conn.set_peer_port(peer_port);
363 
364                 // Add connection object into the backend's maps
365                 self.thread_backend.conn_map.insert(conn_map_key, new_conn);
366 
367                 self.thread_backend
368                     .backend_rxq
369                     .push_back(ConnMapKey::new(local_port, peer_port));
370 
371                 // Re-register the fd to listen for EPOLLIN and EPOLLOUT events
372                 Self::epoll_modify(
373                     self.get_epoll_fd(),
374                     fd,
375                     epoll::Events::EPOLLIN | epoll::Events::EPOLLOUT,
376                 )
377                 .unwrap();
378             } else {
379                 // Previously connected connection
380                 let key = self.thread_backend.listener_map.get(&fd).unwrap();
381                 let conn = self.thread_backend.conn_map.get_mut(key).unwrap();
382 
383                 if evset.bits() == epoll::Events::EPOLLOUT.bits() {
384                     // Flush any remaining data from the tx buffer
385                     match conn.tx_buf.flush_to(&mut conn.stream) {
386                         Ok(cnt) => {
387                             if cnt > 0 {
388                                 conn.fwd_cnt += Wrapping(cnt as u32);
389                                 conn.rx_queue.enqueue(RxOps::CreditUpdate);
390                             }
391                             self.thread_backend
392                                 .backend_rxq
393                                 .push_back(ConnMapKey::new(conn.local_port, conn.peer_port));
394                         }
395                         Err(e) => {
396                             dbg!("Error: {:?}", e);
397                         }
398                     }
399                     return;
400                 }
401 
402                 // Unregister stream from the epoll, register when connection is
403                 // established with the guest
404                 Self::epoll_unregister(self.epoll_file.as_raw_fd(), fd).unwrap();
405 
406                 // Enqueue a read request
407                 conn.rx_queue.enqueue(RxOps::Rw);
408                 self.thread_backend
409                     .backend_rxq
410                     .push_back(ConnMapKey::new(conn.local_port, conn.peer_port));
411             }
412         }
413     }
414 
415     /// Allocate a new local port number.
allocate_local_port(&mut self) -> Result<u32>416     fn allocate_local_port(&mut self) -> Result<u32> {
417         // TODO: Improve space efficiency of this operation
418         // TODO: Reuse the conn_map HashMap
419         // TODO: Test this.
420         let mut alloc_local_port = self.local_port.0;
421         loop {
422             if !self
423                 .thread_backend
424                 .local_port_set
425                 .contains(&alloc_local_port)
426             {
427                 // The port set doesn't contain the newly allocated port number.
428                 self.local_port = Wrapping(alloc_local_port + 1);
429                 self.thread_backend.local_port_set.insert(alloc_local_port);
430                 return Ok(alloc_local_port);
431             } else {
432                 if alloc_local_port == self.local_port.0 {
433                     // We have exhausted our search and wrapped back to the current port number
434                     return Err(Error::NoFreeLocalPort);
435                 }
436                 alloc_local_port += 1;
437             }
438         }
439     }
440 
441     /// Read `CONNECT PORT_NUM\n` from the connected stream.
read_local_stream_port(stream: &mut UnixStream) -> Result<u32>442     fn read_local_stream_port(stream: &mut UnixStream) -> Result<u32> {
443         let mut buf = [0u8; 32];
444 
445         // Minimum number of bytes we should be able to read
446         // Corresponds to 'CONNECT 0\n'
447         const MIN_READ_LEN: usize = 10;
448 
449         // Read in the minimum number of bytes we can read
450         stream
451             .read_exact(&mut buf[..MIN_READ_LEN])
452             .map_err(Error::UnixRead)?;
453 
454         let mut read_len = MIN_READ_LEN;
455         while buf[read_len - 1] != b'\n' && read_len < buf.len() {
456             stream
457                 .read_exact(&mut buf[read_len..read_len + 1])
458                 .map_err(Error::UnixRead)?;
459             read_len += 1;
460         }
461 
462         let mut word_iter = std::str::from_utf8(&buf[..read_len])
463             .map_err(Error::ConvertFromUtf8)?
464             .split_whitespace();
465 
466         word_iter
467             .next()
468             .ok_or(Error::InvalidPortRequest)
469             .and_then(|word| {
470                 if word.to_lowercase() == "connect" {
471                     Ok(())
472                 } else {
473                     Err(Error::InvalidPortRequest)
474                 }
475             })
476             .and_then(|_| word_iter.next().ok_or(Error::InvalidPortRequest))
477             .and_then(|word| word.parse::<u32>().map_err(Error::ParseInteger))
478             .map_err(|e| Error::ReadStreamPort(Box::new(e)))
479     }
480 
481     /// Add a stream to epoll to listen for EPOLLIN events.
add_stream_listener(&mut self, stream: UnixStream) -> Result<()>482     fn add_stream_listener(&mut self, stream: UnixStream) -> Result<()> {
483         let stream_fd = stream.as_raw_fd();
484         self.thread_backend.stream_map.insert(stream_fd, stream);
485         VhostUserVsockThread::epoll_register(
486             self.get_epoll_fd(),
487             stream_fd,
488             epoll::Events::EPOLLIN,
489         )?;
490 
491         Ok(())
492     }
493 
    /// Iterate over the rx queue and process rx requests.
    ///
    /// Pops available descriptor chains from `vring`, fills each one with a
    /// packet taken from the backend queue selected by `rx_queue_type`, and
    /// hands completion (add_used + guest notification) off to the worker
    /// thread via `self.sender`.
    ///
    /// Returns `Ok(true)` if at least one descriptor chain was consumed;
    /// `Err(Error::NoMemoryConfigured)` if guest memory is not yet mapped.
    fn process_rx_queue(
        &mut self,
        vring: &VringRwLock,
        rx_queue_type: RxQueueType,
    ) -> Result<bool> {
        let mut used_any = false;
        let atomic_mem = match &self.mem {
            Some(m) => m,
            None => return Err(Error::NoMemoryConfigured),
        };

        let mut vring_mut = vring.get_mut();

        let queue = vring_mut.get_queue_mut();

        // A fresh iterator is created on each pass; `.next()` pulls at most
        // one descriptor chain per loop iteration.
        while let Some(mut avail_desc) = queue
            .iter(atomic_mem.memory())
            .map_err(|_| Error::IterateQueue)?
            .next()
        {
            used_any = true;
            let mem = atomic_mem.clone().memory();

            let head_idx = avail_desc.head_index();
            let used_len = match VsockPacket::from_rx_virtq_chain(
                mem.deref(),
                &mut avail_desc,
                self.tx_buffer_size,
            ) {
                Ok(mut pkt) => {
                    let recv_result = match rx_queue_type {
                        RxQueueType::Standard => self.thread_backend.recv_pkt(&mut pkt),
                        RxQueueType::RawPkts => self.thread_backend.recv_raw_pkt(&mut pkt),
                    };

                    if recv_result.is_ok() {
                        // Count the header in the bytes written to the chain.
                        PKT_HEADER_SIZE + pkt.len() as usize
                    } else {
                        // Nothing could be received: give the descriptor chain
                        // back to the avail ring and stop this pass.
                        queue.iter(mem).unwrap().go_to_previous_position();
                        break;
                    }
                }
                Err(e) => {
                    // Malformed chain: return it with zero bytes written.
                    warn!("vsock: RX queue error: {:?}", e);
                    0
                }
            };

            // Completion is performed asynchronously by the worker thread
            // spawned in `new`.
            let vring = vring.clone();
            let event_idx = self.event_idx;
            self.sender
                .send(EventData {
                    vring,
                    event_idx,
                    head_idx,
                    used_len,
                })
                .unwrap();

            // Stop once the backend has no more packets of this type pending.
            match rx_queue_type {
                RxQueueType::Standard => {
                    if !self.thread_backend.pending_rx() {
                        break;
                    }
                }
                RxQueueType::RawPkts => {
                    if !self.thread_backend.pending_raw_pkts() {
                        break;
                    }
                }
            }
        }
        Ok(used_any)
    }
569 
570     /// Wrapper to process rx queue based on whether event idx is enabled or not.
process_unix_sockets(&mut self, vring: &VringRwLock, event_idx: bool) -> Result<bool>571     fn process_unix_sockets(&mut self, vring: &VringRwLock, event_idx: bool) -> Result<bool> {
572         if event_idx {
573             // To properly handle EVENT_IDX we need to keep calling
574             // process_rx_queue until it stops finding new requests
575             // on the queue, as vm-virtio's Queue implementation
576             // only checks avail_index once
577             loop {
578                 if !self.thread_backend.pending_rx() {
579                     break;
580                 }
581                 vring.disable_notification().unwrap();
582 
583                 self.process_rx_queue(vring, RxQueueType::Standard)?;
584                 if !vring.enable_notification().unwrap() {
585                     break;
586                 }
587             }
588         } else {
589             self.process_rx_queue(vring, RxQueueType::Standard)?;
590         }
591         Ok(false)
592     }
593 
594     /// Wrapper to process raw vsock packets queue based on whether event idx is enabled or not.
process_raw_pkts(&mut self, vring: &VringRwLock, event_idx: bool) -> Result<bool>595     pub fn process_raw_pkts(&mut self, vring: &VringRwLock, event_idx: bool) -> Result<bool> {
596         if event_idx {
597             loop {
598                 if !self.thread_backend.pending_raw_pkts() {
599                     break;
600                 }
601                 vring.disable_notification().unwrap();
602 
603                 self.process_rx_queue(vring, RxQueueType::RawPkts)?;
604                 if !vring.enable_notification().unwrap() {
605                     break;
606                 }
607             }
608         } else {
609             self.process_rx_queue(vring, RxQueueType::RawPkts)?;
610         }
611         Ok(false)
612     }
613 
process_rx(&mut self, vring: &VringRwLock, event_idx: bool) -> Result<bool>614     pub fn process_rx(&mut self, vring: &VringRwLock, event_idx: bool) -> Result<bool> {
615         match self.last_processed {
616             RxQueueType::Standard => {
617                 if self.thread_backend.pending_raw_pkts() {
618                     self.process_raw_pkts(vring, event_idx)?;
619                     self.last_processed = RxQueueType::RawPkts;
620                 }
621                 if self.thread_backend.pending_rx() {
622                     self.process_unix_sockets(vring, event_idx)?;
623                 }
624             }
625             RxQueueType::RawPkts => {
626                 if self.thread_backend.pending_rx() {
627                     self.process_unix_sockets(vring, event_idx)?;
628                     self.last_processed = RxQueueType::Standard;
629                 }
630                 if self.thread_backend.pending_raw_pkts() {
631                     self.process_raw_pkts(vring, event_idx)?;
632                 }
633             }
634         }
635         Ok(false)
636     }
637 
638     /// Process tx queue and send requests to the backend for processing.
process_tx_queue(&mut self, vring: &VringRwLock) -> Result<bool>639     fn process_tx_queue(&mut self, vring: &VringRwLock) -> Result<bool> {
640         let mut used_any = false;
641 
642         let atomic_mem = match &self.mem {
643             Some(m) => m,
644             None => return Err(Error::NoMemoryConfigured),
645         };
646 
647         while let Some(mut avail_desc) = vring
648             .get_mut()
649             .get_queue_mut()
650             .iter(atomic_mem.memory())
651             .map_err(|_| Error::IterateQueue)?
652             .next()
653         {
654             used_any = true;
655             let mem = atomic_mem.clone().memory();
656 
657             let head_idx = avail_desc.head_index();
658             let pkt = match VsockPacket::from_tx_virtq_chain(
659                 mem.deref(),
660                 &mut avail_desc,
661                 self.tx_buffer_size,
662             ) {
663                 Ok(pkt) => pkt,
664                 Err(e) => {
665                     dbg!("vsock: error reading TX packet: {:?}", e);
666                     continue;
667                 }
668             };
669 
670             if self.thread_backend.send_pkt(&pkt).is_err() {
671                 vring
672                     .get_mut()
673                     .get_queue_mut()
674                     .iter(mem)
675                     .unwrap()
676                     .go_to_previous_position();
677                 break;
678             }
679 
680             // TODO: Check if the protocol requires read length to be correct
681             let used_len = 0;
682 
683             let vring = vring.clone();
684             let event_idx = self.event_idx;
685             self.sender
686                 .send(EventData {
687                     vring,
688                     event_idx,
689                     head_idx,
690                     used_len,
691                 })
692                 .unwrap();
693         }
694 
695         Ok(used_any)
696     }
697 
698     /// Wrapper to process tx queue based on whether event idx is enabled or not.
process_tx(&mut self, vring_lock: &VringRwLock, event_idx: bool) -> Result<bool>699     pub fn process_tx(&mut self, vring_lock: &VringRwLock, event_idx: bool) -> Result<bool> {
700         if event_idx {
701             // To properly handle EVENT_IDX we need to keep calling
702             // process_rx_queue until it stops finding new requests
703             // on the queue, as vm-virtio's Queue implementation
704             // only checks avail_index once
705             loop {
706                 vring_lock.disable_notification().unwrap();
707                 self.process_tx_queue(vring_lock)?;
708                 if !vring_lock.enable_notification().unwrap() {
709                     break;
710                 }
711             }
712         } else {
713             self.process_tx_queue(vring_lock)?;
714         }
715         Ok(false)
716     }
717 }
718 
719 impl Drop for VhostUserVsockThread {
drop(&mut self)720     fn drop(&mut self) {
721         let _ = std::fs::remove_file(&self.host_sock_path);
722         self.thread_backend
723             .cid_map
724             .write()
725             .unwrap()
726             .remove(&self.guest_cid);
727     }
728 }
729 #[cfg(test)]
730 mod tests {
731     use super::*;
732     use std::collections::HashMap;
733     use tempfile::tempdir;
734     use vm_memory::GuestAddress;
735     use vmm_sys_util::eventfd::EventFd;
736 
737     const CONN_TX_BUF_SIZE: u32 = 64 * 1024;
738 
    impl VhostUserVsockThread {
        // Test-only accessor exposing the otherwise-private epoll file.
        fn get_epoll_file(&self) -> &File {
            &self.epoll_file
        }
    }
744 
    #[test]
    fn test_vsock_thread() {
        let groups: Vec<String> = vec![String::from("default")];

        let cid_map: Arc<RwLock<CidMap>> = Arc::new(RwLock::new(HashMap::new()));

        let test_dir = tempdir().expect("Could not create a temp test directory.");

        // Happy-path construction with a fresh socket path and CID 3.
        let t = VhostUserVsockThread::new(
            test_dir
                .path()
                .join("test_vsock_thread.vsock")
                .display()
                .to_string(),
            3,
            CONN_TX_BUF_SIZE,
            groups,
            cid_map,
        );
        assert!(t.is_ok());

        let mut t = t.unwrap();
        let epoll_fd = t.get_epoll_file().as_raw_fd();

        let mem = GuestMemoryAtomic::new(
            GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0), 0x10000)]).unwrap(),
        );

        t.mem = Some(mem.clone());

        // Exercise the epoll helper fns against the thread's real epoll fd.
        let dummy_fd = EventFd::new(0).unwrap();

        assert!(VhostUserVsockThread::epoll_register(
            epoll_fd,
            dummy_fd.as_raw_fd(),
            epoll::Events::EPOLLOUT
        )
        .is_ok());
        assert!(VhostUserVsockThread::epoll_modify(
            epoll_fd,
            dummy_fd.as_raw_fd(),
            epoll::Events::EPOLLIN
        )
        .is_ok());
        assert!(VhostUserVsockThread::epoll_unregister(epoll_fd, dummy_fd.as_raw_fd()).is_ok());
        assert!(VhostUserVsockThread::epoll_register(
            epoll_fd,
            dummy_fd.as_raw_fd(),
            epoll::Events::EPOLLIN
        )
        .is_ok());

        // Minimal ready vring so queue processing can run with memory mapped.
        let vring = VringRwLock::new(mem, 0x1000).unwrap();
        vring.set_queue_info(0x100, 0x200, 0x300).unwrap();
        vring.set_queue_ready(true);

        assert!(t.process_tx(&vring, false).is_ok());
        assert!(t.process_tx(&vring, true).is_ok());
        // add backend_rxq to avoid that RX processing is skipped
        t.thread_backend
            .backend_rxq
            .push_back(ConnMapKey::new(0, 0));
        assert!(t.process_rx(&vring, false).is_ok());
        assert!(t.process_rx(&vring, true).is_ok());

        // Make the registered dummy fd readable, then drain backend events.
        dummy_fd.write(1).unwrap();

        t.process_backend_evt(EventSet::empty());

        test_dir.close().unwrap();
    }
816 
    #[test]
    fn test_vsock_thread_failures() {
        let groups: Vec<String> = vec![String::from("default")];

        let cid_map: Arc<RwLock<CidMap>> = Arc::new(RwLock::new(HashMap::new()));

        let test_dir = tempdir().expect("Could not create a temp test directory.");

        // Binding under /sys is not permitted, so construction must fail.
        let t = VhostUserVsockThread::new(
            "/sys/not_allowed.vsock".to_string(),
            3,
            CONN_TX_BUF_SIZE,
            groups.clone(),
            cid_map.clone(),
        );
        assert!(t.is_err());

        let vsock_socket_path = test_dir
            .path()
            .join("test_vsock_thread_failures.vsock")
            .display()
            .to_string();
        let mut t = VhostUserVsockThread::new(
            vsock_socket_path,
            3,
            CONN_TX_BUF_SIZE,
            groups.clone(),
            cid_map.clone(),
        )
        .unwrap();
        // Invalid fds must be rejected by all epoll helpers.
        assert!(VhostUserVsockThread::epoll_register(-1, -1, epoll::Events::EPOLLIN).is_err());
        assert!(VhostUserVsockThread::epoll_modify(-1, -1, epoll::Events::EPOLLIN).is_err());
        assert!(VhostUserVsockThread::epoll_unregister(-1, -1).is_err());

        let mem = GuestMemoryAtomic::new(
            GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0), 0x10000)]).unwrap(),
        );

        let vring = VringRwLock::new(mem, 0x1000).unwrap();

        // memory is not configured, so processing TX should fail
        assert!(t.process_tx(&vring, false).is_err());
        assert!(t.process_tx(&vring, true).is_err());

        // add backend_rxq to avoid that RX processing is skipped
        t.thread_backend
            .backend_rxq
            .push_back(ConnMapKey::new(0, 0));
        assert!(t.process_rx(&vring, false).is_err());
        assert!(t.process_rx(&vring, true).is_err());

        // trying to use a CID that is already in use should fail
        let vsock_socket_path2 = test_dir
            .path()
            .join("test_vsock_thread_failures2.vsock")
            .display()
            .to_string();
        let t2 =
            VhostUserVsockThread::new(vsock_socket_path2, 3, CONN_TX_BUF_SIZE, groups, cid_map);
        assert!(t2.is_err());

        test_dir.close().unwrap();
    }
880 }
881