1 use super::{
2     protocol::VsockAddr, vsock::ConnectionInfo, DisconnectReason, SocketError, VirtIOSocket,
3     VsockEvent, VsockEventType, DEFAULT_RX_BUFFER_SIZE,
4 };
5 use crate::{transport::Transport, Hal, Result};
6 use alloc::{boxed::Box, vec::Vec};
7 use core::cmp::min;
8 use core::convert::TryInto;
9 use core::hint::spin_loop;
10 use log::debug;
11 use zerocopy::FromZeroes;
12 
/// Default size, in bytes, of the receive buffer allocated for each connection. The same value is
/// advertised to the peer as the connection's `buf_alloc`.
const DEFAULT_PER_CONNECTION_BUFFER_CAPACITY: u32 = 1 << 10;
14 
15 /// A higher level interface for VirtIO socket (vsock) devices.
16 ///
17 /// This keeps track of multiple vsock connections.
18 ///
19 /// `RX_BUFFER_SIZE` is the size in bytes of each buffer used in the RX virtqueue. This must be
20 /// bigger than `size_of::<VirtioVsockHdr>()`.
21 ///
22 /// # Example
23 ///
24 /// ```
25 /// # use virtio_drivers::{Error, Hal};
26 /// # use virtio_drivers::transport::Transport;
27 /// use virtio_drivers::device::socket::{VirtIOSocket, VsockAddr, VsockConnectionManager};
28 ///
29 /// # fn example<HalImpl: Hal, T: Transport>(transport: T) -> Result<(), Error> {
30 /// let mut socket = VsockConnectionManager::new(VirtIOSocket::<HalImpl, _>::new(transport)?);
31 ///
32 /// // Start a thread to call `socket.poll()` and handle events.
33 ///
34 /// let remote_address = VsockAddr { cid: 2, port: 42 };
35 /// let local_port = 1234;
36 /// socket.connect(remote_address, local_port)?;
37 ///
38 /// // Wait until `socket.poll()` returns an event indicating that the socket is connected.
39 ///
40 /// socket.send(remote_address, local_port, "Hello world".as_bytes())?;
41 ///
42 /// socket.shutdown(remote_address, local_port)?;
43 /// # Ok(())
44 /// # }
45 /// ```
46 pub struct VsockConnectionManager<
47     H: Hal,
48     T: Transport,
49     const RX_BUFFER_SIZE: usize = DEFAULT_RX_BUFFER_SIZE,
50 > {
51     driver: VirtIOSocket<H, T, RX_BUFFER_SIZE>,
52     per_connection_buffer_capacity: u32,
53     connections: Vec<Connection>,
54     listening_ports: Vec<u32>,
55 }
56 
57 #[derive(Debug)]
58 struct Connection {
59     info: ConnectionInfo,
60     buffer: RingBuffer,
61     /// The peer sent a SHUTDOWN request, but we haven't yet responded with a RST because there is
62     /// still data in the buffer.
63     peer_requested_shutdown: bool,
64 }
65 
66 impl Connection {
new(peer: VsockAddr, local_port: u32, buffer_capacity: u32) -> Self67     fn new(peer: VsockAddr, local_port: u32, buffer_capacity: u32) -> Self {
68         let mut info = ConnectionInfo::new(peer, local_port);
69         info.buf_alloc = buffer_capacity;
70         Self {
71             info,
72             buffer: RingBuffer::new(buffer_capacity.try_into().unwrap()),
73             peer_requested_shutdown: false,
74         }
75     }
76 }
77 
78 impl<H: Hal, T: Transport, const RX_BUFFER_SIZE: usize>
79     VsockConnectionManager<H, T, RX_BUFFER_SIZE>
80 {
81     /// Construct a new connection manager wrapping the given low-level VirtIO socket driver.
new(driver: VirtIOSocket<H, T, RX_BUFFER_SIZE>) -> Self82     pub fn new(driver: VirtIOSocket<H, T, RX_BUFFER_SIZE>) -> Self {
83         Self::new_with_capacity(driver, DEFAULT_PER_CONNECTION_BUFFER_CAPACITY)
84     }
85 
86     /// Construct a new connection manager wrapping the given low-level VirtIO socket driver, with
87     /// the given per-connection buffer capacity.
new_with_capacity( driver: VirtIOSocket<H, T, RX_BUFFER_SIZE>, per_connection_buffer_capacity: u32, ) -> Self88     pub fn new_with_capacity(
89         driver: VirtIOSocket<H, T, RX_BUFFER_SIZE>,
90         per_connection_buffer_capacity: u32,
91     ) -> Self {
92         Self {
93             driver,
94             connections: Vec::new(),
95             listening_ports: Vec::new(),
96             per_connection_buffer_capacity,
97         }
98     }
99 
100     /// Returns the CID which has been assigned to this guest.
guest_cid(&self) -> u64101     pub fn guest_cid(&self) -> u64 {
102         self.driver.guest_cid()
103     }
104 
105     /// Allows incoming connections on the given port number.
listen(&mut self, port: u32)106     pub fn listen(&mut self, port: u32) {
107         if !self.listening_ports.contains(&port) {
108             self.listening_ports.push(port);
109         }
110     }
111 
112     /// Stops allowing incoming connections on the given port number.
unlisten(&mut self, port: u32)113     pub fn unlisten(&mut self, port: u32) {
114         self.listening_ports.retain(|p| *p != port);
115     }
116 
117     /// Sends a request to connect to the given destination.
118     ///
119     /// This returns as soon as the request is sent; you should wait until `poll` returns a
120     /// `VsockEventType::Connected` event indicating that the peer has accepted the connection
121     /// before sending data.
connect(&mut self, destination: VsockAddr, src_port: u32) -> Result122     pub fn connect(&mut self, destination: VsockAddr, src_port: u32) -> Result {
123         if self.connections.iter().any(|connection| {
124             connection.info.dst == destination && connection.info.src_port == src_port
125         }) {
126             return Err(SocketError::ConnectionExists.into());
127         }
128 
129         let new_connection =
130             Connection::new(destination, src_port, self.per_connection_buffer_capacity);
131 
132         self.driver.connect(&new_connection.info)?;
133         debug!("Connection requested: {:?}", new_connection.info);
134         self.connections.push(new_connection);
135         Ok(())
136     }
137 
138     /// Sends the buffer to the destination.
send(&mut self, destination: VsockAddr, src_port: u32, buffer: &[u8]) -> Result139     pub fn send(&mut self, destination: VsockAddr, src_port: u32, buffer: &[u8]) -> Result {
140         let (_, connection) = get_connection(&mut self.connections, destination, src_port)?;
141 
142         self.driver.send(buffer, &mut connection.info)
143     }
144 
145     /// Polls the vsock device to receive data or other updates.
poll(&mut self) -> Result<Option<VsockEvent>>146     pub fn poll(&mut self) -> Result<Option<VsockEvent>> {
147         let guest_cid = self.driver.guest_cid();
148         let connections = &mut self.connections;
149         let per_connection_buffer_capacity = self.per_connection_buffer_capacity;
150 
151         let result = self.driver.poll(|event, body| {
152             let connection = get_connection_for_event(connections, &event, guest_cid);
153 
154             // Skip events which don't match any connection we know about, unless they are a
155             // connection request.
156             let connection = if let Some((_, connection)) = connection {
157                 connection
158             } else if let VsockEventType::ConnectionRequest = event.event_type {
159                 // If the requested connection already exists or the CID isn't ours, ignore it.
160                 if connection.is_some() || event.destination.cid != guest_cid {
161                     return Ok(None);
162                 }
163                 // Add the new connection to our list, at least for now. It will be removed again
164                 // below if we weren't listening on the port.
165                 connections.push(Connection::new(
166                     event.source,
167                     event.destination.port,
168                     per_connection_buffer_capacity,
169                 ));
170                 connections.last_mut().unwrap()
171             } else {
172                 return Ok(None);
173             };
174 
175             // Update stored connection info.
176             connection.info.update_for_event(&event);
177 
178             if let VsockEventType::Received { length } = event.event_type {
179                 // Copy to buffer
180                 if !connection.buffer.add(body) {
181                     return Err(SocketError::OutputBufferTooShort(length).into());
182                 }
183             }
184 
185             Ok(Some(event))
186         })?;
187 
188         let Some(event) = result else {
189             return Ok(None);
190         };
191 
192         // The connection must exist because we found it above in the callback.
193         let (connection_index, connection) =
194             get_connection_for_event(connections, &event, guest_cid).unwrap();
195 
196         match event.event_type {
197             VsockEventType::ConnectionRequest => {
198                 if self.listening_ports.contains(&event.destination.port) {
199                     self.driver.accept(&connection.info)?;
200                 } else {
201                     // Reject the connection request and remove it from our list.
202                     self.driver.force_close(&connection.info)?;
203                     self.connections.swap_remove(connection_index);
204 
205                     // No need to pass the request on to the client, as we've already rejected it.
206                     return Ok(None);
207                 }
208             }
209             VsockEventType::Connected => {}
210             VsockEventType::Disconnected { reason } => {
211                 // Wait until client reads all data before removing connection.
212                 if connection.buffer.is_empty() {
213                     if reason == DisconnectReason::Shutdown {
214                         self.driver.force_close(&connection.info)?;
215                     }
216                     self.connections.swap_remove(connection_index);
217                 } else {
218                     connection.peer_requested_shutdown = true;
219                 }
220             }
221             VsockEventType::Received { .. } => {
222                 // Already copied the buffer in the callback above.
223             }
224             VsockEventType::CreditRequest => {
225                 // If the peer requested credit, send an update.
226                 self.driver.credit_update(&connection.info)?;
227                 // No need to pass the request on to the client, we've already handled it.
228                 return Ok(None);
229             }
230             VsockEventType::CreditUpdate => {}
231         }
232 
233         Ok(Some(event))
234     }
235 
236     /// Reads data received from the given connection.
recv(&mut self, peer: VsockAddr, src_port: u32, buffer: &mut [u8]) -> Result<usize>237     pub fn recv(&mut self, peer: VsockAddr, src_port: u32, buffer: &mut [u8]) -> Result<usize> {
238         let (connection_index, connection) = get_connection(&mut self.connections, peer, src_port)?;
239 
240         // Copy from ring buffer
241         let bytes_read = connection.buffer.drain(buffer);
242 
243         connection.info.done_forwarding(bytes_read);
244 
245         // If buffer is now empty and the peer requested shutdown, finish shutting down the
246         // connection.
247         if connection.peer_requested_shutdown && connection.buffer.is_empty() {
248             self.driver.force_close(&connection.info)?;
249             self.connections.swap_remove(connection_index);
250         }
251 
252         Ok(bytes_read)
253     }
254 
255     /// Returns the number of bytes in the receive buffer available to be read by `recv`.
256     ///
257     /// When the available bytes is 0, it indicates that the receive buffer is empty and does not
258     /// contain any data.
recv_buffer_available_bytes(&mut self, peer: VsockAddr, src_port: u32) -> Result<usize>259     pub fn recv_buffer_available_bytes(&mut self, peer: VsockAddr, src_port: u32) -> Result<usize> {
260         let (_, connection) = get_connection(&mut self.connections, peer, src_port)?;
261         Ok(connection.buffer.used())
262     }
263 
264     /// Sends a credit update to the given peer.
update_credit(&mut self, peer: VsockAddr, src_port: u32) -> Result265     pub fn update_credit(&mut self, peer: VsockAddr, src_port: u32) -> Result {
266         let (_, connection) = get_connection(&mut self.connections, peer, src_port)?;
267         self.driver.credit_update(&connection.info)
268     }
269 
270     /// Blocks until we get some event from the vsock device.
wait_for_event(&mut self) -> Result<VsockEvent>271     pub fn wait_for_event(&mut self) -> Result<VsockEvent> {
272         loop {
273             if let Some(event) = self.poll()? {
274                 return Ok(event);
275             } else {
276                 spin_loop();
277             }
278         }
279     }
280 
281     /// Requests to shut down the connection cleanly, telling the peer that we won't send or receive
282     /// any more data.
283     ///
284     /// This returns as soon as the request is sent; you should wait until `poll` returns a
285     /// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the
286     /// shutdown.
shutdown(&mut self, destination: VsockAddr, src_port: u32) -> Result287     pub fn shutdown(&mut self, destination: VsockAddr, src_port: u32) -> Result {
288         let (_, connection) = get_connection(&mut self.connections, destination, src_port)?;
289 
290         self.driver.shutdown(&connection.info)
291     }
292 
293     /// Forcibly closes the connection without waiting for the peer.
force_close(&mut self, destination: VsockAddr, src_port: u32) -> Result294     pub fn force_close(&mut self, destination: VsockAddr, src_port: u32) -> Result {
295         let (index, connection) = get_connection(&mut self.connections, destination, src_port)?;
296 
297         self.driver.force_close(&connection.info)?;
298 
299         self.connections.swap_remove(index);
300         Ok(())
301     }
302 }
303 
304 /// Returns the connection from the given list matching the given peer address and local port, and
305 /// its index.
306 ///
307 /// Returns `Err(SocketError::NotConnected)` if there is no matching connection in the list.
get_connection( connections: &mut [Connection], peer: VsockAddr, local_port: u32, ) -> core::result::Result<(usize, &mut Connection), SocketError>308 fn get_connection(
309     connections: &mut [Connection],
310     peer: VsockAddr,
311     local_port: u32,
312 ) -> core::result::Result<(usize, &mut Connection), SocketError> {
313     connections
314         .iter_mut()
315         .enumerate()
316         .find(|(_, connection)| {
317             connection.info.dst == peer && connection.info.src_port == local_port
318         })
319         .ok_or(SocketError::NotConnected)
320 }
321 
322 /// Returns the connection from the given list matching the event, if any, and its index.
get_connection_for_event<'a>( connections: &'a mut [Connection], event: &VsockEvent, local_cid: u64, ) -> Option<(usize, &'a mut Connection)>323 fn get_connection_for_event<'a>(
324     connections: &'a mut [Connection],
325     event: &VsockEvent,
326     local_cid: u64,
327 ) -> Option<(usize, &'a mut Connection)> {
328     connections
329         .iter_mut()
330         .enumerate()
331         .find(|(_, connection)| event.matches_connection(&connection.info, local_cid))
332 }
333 
334 #[derive(Debug)]
335 struct RingBuffer {
336     buffer: Box<[u8]>,
337     /// The number of bytes currently in the buffer.
338     used: usize,
339     /// The index of the first used byte in the buffer.
340     start: usize,
341 }
342 
343 impl RingBuffer {
new(capacity: usize) -> Self344     pub fn new(capacity: usize) -> Self {
345         Self {
346             buffer: FromZeroes::new_box_slice_zeroed(capacity),
347             used: 0,
348             start: 0,
349         }
350     }
351 
352     /// Returns the number of bytes currently used in the buffer.
used(&self) -> usize353     pub fn used(&self) -> usize {
354         self.used
355     }
356 
357     /// Returns true iff there are currently no bytes in the buffer.
is_empty(&self) -> bool358     pub fn is_empty(&self) -> bool {
359         self.used == 0
360     }
361 
362     /// Returns the number of bytes currently free in the buffer.
free(&self) -> usize363     pub fn free(&self) -> usize {
364         self.buffer.len() - self.used
365     }
366 
367     /// Adds the given bytes to the buffer if there is enough capacity for them all.
368     ///
369     /// Returns true if they were added, or false if they were not.
add(&mut self, bytes: &[u8]) -> bool370     pub fn add(&mut self, bytes: &[u8]) -> bool {
371         if bytes.len() > self.free() {
372             return false;
373         }
374 
375         // The index of the first available position in the buffer.
376         let first_available = (self.start + self.used) % self.buffer.len();
377         // The number of bytes to copy from `bytes` to `buffer` between `first_available` and
378         // `buffer.len()`.
379         let copy_length_before_wraparound = min(bytes.len(), self.buffer.len() - first_available);
380         self.buffer[first_available..first_available + copy_length_before_wraparound]
381             .copy_from_slice(&bytes[0..copy_length_before_wraparound]);
382         if let Some(bytes_after_wraparound) = bytes.get(copy_length_before_wraparound..) {
383             self.buffer[0..bytes_after_wraparound.len()].copy_from_slice(bytes_after_wraparound);
384         }
385         self.used += bytes.len();
386 
387         true
388     }
389 
390     /// Reads and removes as many bytes as possible from the buffer, up to the length of the given
391     /// buffer.
drain(&mut self, out: &mut [u8]) -> usize392     pub fn drain(&mut self, out: &mut [u8]) -> usize {
393         let bytes_read = min(self.used, out.len());
394 
395         // The number of bytes to copy out between `start` and the end of the buffer.
396         let read_before_wraparound = min(bytes_read, self.buffer.len() - self.start);
397         // The number of bytes to copy out from the beginning of the buffer after wrapping around.
398         let read_after_wraparound = bytes_read
399             .checked_sub(read_before_wraparound)
400             .unwrap_or_default();
401 
402         out[0..read_before_wraparound]
403             .copy_from_slice(&self.buffer[self.start..self.start + read_before_wraparound]);
404         out[read_before_wraparound..bytes_read]
405             .copy_from_slice(&self.buffer[0..read_after_wraparound]);
406 
407         self.used -= bytes_read;
408         self.start = (self.start + bytes_read) % self.buffer.len();
409 
410         bytes_read
411     }
412 }
413 
// Tests which drive the connection manager against a `FakeTransport`, with a separate thread
// playing the part of the host end of the vsock device.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        device::socket::{
            protocol::{
                SocketType, StreamShutdown, VirtioVsockConfig, VirtioVsockHdr, VirtioVsockOp,
            },
            vsock::{VsockBufferStatus, QUEUE_SIZE, RX_QUEUE_IDX, TX_QUEUE_IDX},
        },
        hal::fake::FakeHal,
        transport::{
            fake::{FakeTransport, QueueStatus, State},
            DeviceType,
        },
        volatile::ReadOnly,
    };
    use alloc::{sync::Arc, vec};
    use core::{mem::size_of, ptr::NonNull};
    use std::{sync::Mutex, thread};
    use zerocopy::{AsBytes, FromBytes};

    /// Tests a connection initiated by the guest: the guest connects to the host, each side sends
    /// one message, and then the guest shuts the connection down.
    #[test]
    fn send_recv() {
        let host_cid = 2;
        let guest_cid = 66;
        let host_port = 1234;
        let guest_port = 4321;
        let host_address = VsockAddr {
            cid: host_cid,
            port: host_port,
        };
        let hello_from_guest = "Hello from guest";
        let hello_from_host = "Hello from host";

        // Config space of the fake device, reporting a guest CID of 66 split into low and high
        // halves.
        let mut config_space = VirtioVsockConfig {
            guest_cid_low: ReadOnly::new(66),
            guest_cid_high: ReadOnly::new(0),
        };
        // Queue state shared between the fake transport and the device thread below; one entry
        // per virtqueue.
        let state = Arc::new(Mutex::new(State {
            queues: vec![
                QueueStatus::default(),
                QueueStatus::default(),
                QueueStatus::default(),
            ],
            ..Default::default()
        }));
        let transport = FakeTransport {
            device_type: DeviceType::Socket,
            max_queue_size: 32,
            device_features: 0,
            config_space: NonNull::from(&mut config_space),
            state: state.clone(),
        };
        let mut socket = VsockConnectionManager::new(
            VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(transport).unwrap(),
        );

        // Start a thread to simulate the device.
        let handle = thread::spawn(move || {
            // Wait for connection request.
            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
            assert_eq!(
                VirtioVsockHdr::read_from(
                    state
                        .lock()
                        .unwrap()
                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
                        .as_slice()
                )
                .unwrap(),
                VirtioVsockHdr {
                    op: VirtioVsockOp::Request.into(),
                    src_cid: guest_cid.into(),
                    dst_cid: host_cid.into(),
                    src_port: guest_port.into(),
                    dst_port: host_port.into(),
                    len: 0.into(),
                    socket_type: SocketType::Stream.into(),
                    flags: 0.into(),
                    buf_alloc: 1024.into(),
                    fwd_cnt: 0.into(),
                }
            );

            // Accept connection and give the peer enough credit to send the message.
            state.lock().unwrap().write_to_queue::<QUEUE_SIZE>(
                RX_QUEUE_IDX,
                VirtioVsockHdr {
                    op: VirtioVsockOp::Response.into(),
                    src_cid: host_cid.into(),
                    dst_cid: guest_cid.into(),
                    src_port: host_port.into(),
                    dst_port: guest_port.into(),
                    len: 0.into(),
                    socket_type: SocketType::Stream.into(),
                    flags: 0.into(),
                    buf_alloc: 50.into(),
                    fwd_cnt: 0.into(),
                }
                .as_bytes(),
            );

            // Expect the guest to send some data.
            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
            let request = state
                .lock()
                .unwrap()
                .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX);
            assert_eq!(
                request.len(),
                size_of::<VirtioVsockHdr>() + hello_from_guest.len()
            );
            assert_eq!(
                VirtioVsockHdr::read_from_prefix(request.as_slice()).unwrap(),
                VirtioVsockHdr {
                    op: VirtioVsockOp::Rw.into(),
                    src_cid: guest_cid.into(),
                    dst_cid: host_cid.into(),
                    src_port: guest_port.into(),
                    dst_port: host_port.into(),
                    len: (hello_from_guest.len() as u32).into(),
                    socket_type: SocketType::Stream.into(),
                    flags: 0.into(),
                    buf_alloc: 1024.into(),
                    fwd_cnt: 0.into(),
                }
            );
            assert_eq!(
                &request[size_of::<VirtioVsockHdr>()..],
                hello_from_guest.as_bytes()
            );

            println!("Host sending");

            // Send a response.
            let mut response = vec![0; size_of::<VirtioVsockHdr>() + hello_from_host.len()];
            VirtioVsockHdr {
                op: VirtioVsockOp::Rw.into(),
                src_cid: host_cid.into(),
                dst_cid: guest_cid.into(),
                src_port: host_port.into(),
                dst_port: guest_port.into(),
                len: (hello_from_host.len() as u32).into(),
                socket_type: SocketType::Stream.into(),
                flags: 0.into(),
                buf_alloc: 50.into(),
                fwd_cnt: (hello_from_guest.len() as u32).into(),
            }
            .write_to_prefix(response.as_mut_slice());
            response[size_of::<VirtioVsockHdr>()..].copy_from_slice(hello_from_host.as_bytes());
            state
                .lock()
                .unwrap()
                .write_to_queue::<QUEUE_SIZE>(RX_QUEUE_IDX, &response);

            // Expect a shutdown. By now the guest has read the host's message with `recv`, so its
            // `fwd_cnt` is expected to cover those bytes.
            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
            assert_eq!(
                VirtioVsockHdr::read_from(
                    state
                        .lock()
                        .unwrap()
                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
                        .as_slice()
                )
                .unwrap(),
                VirtioVsockHdr {
                    op: VirtioVsockOp::Shutdown.into(),
                    src_cid: guest_cid.into(),
                    dst_cid: host_cid.into(),
                    src_port: guest_port.into(),
                    dst_port: host_port.into(),
                    len: 0.into(),
                    socket_type: SocketType::Stream.into(),
                    flags: (StreamShutdown::SEND | StreamShutdown::RECEIVE).into(),
                    buf_alloc: 1024.into(),
                    fwd_cnt: (hello_from_host.len() as u32).into(),
                }
            );
        });

        // Guest side of the exchange; each step pairs with a step of the device thread above.
        socket.connect(host_address, guest_port).unwrap();
        assert_eq!(
            socket.wait_for_event().unwrap(),
            VsockEvent {
                source: host_address,
                destination: VsockAddr {
                    cid: guest_cid,
                    port: guest_port,
                },
                event_type: VsockEventType::Connected,
                buffer_status: VsockBufferStatus {
                    buffer_allocation: 50,
                    forward_count: 0,
                },
            }
        );
        println!("Guest sending");
        socket
            .send(host_address, guest_port, "Hello from guest".as_bytes())
            .unwrap();
        println!("Guest waiting to receive.");
        assert_eq!(
            socket.wait_for_event().unwrap(),
            VsockEvent {
                source: host_address,
                destination: VsockAddr {
                    cid: guest_cid,
                    port: guest_port,
                },
                event_type: VsockEventType::Received {
                    length: hello_from_host.len()
                },
                buffer_status: VsockBufferStatus {
                    buffer_allocation: 50,
                    forward_count: hello_from_guest.len() as u32,
                },
            }
        );
        println!("Guest getting received data.");
        let mut buffer = [0u8; 64];
        assert_eq!(
            socket.recv(host_address, guest_port, &mut buffer).unwrap(),
            hello_from_host.len()
        );
        assert_eq!(
            &buffer[0..hello_from_host.len()],
            hello_from_host.as_bytes()
        );
        socket.shutdown(host_address, guest_port).unwrap();

        handle.join().unwrap();
    }

    /// Tests connection requests coming from the host. A request to a port the guest isn't
    /// listening on should be rejected with a RST and not reported to the caller. A request to a
    /// listened port should be accepted and reported as a `ConnectionRequest` event.
    #[test]
    fn incoming_connection() {
        let host_cid = 2;
        let guest_cid = 66;
        let host_port = 1234;
        let guest_port = 4321;
        let wrong_guest_port = 4444;
        let host_address = VsockAddr {
            cid: host_cid,
            port: host_port,
        };

        // Config space of the fake device, reporting a guest CID of 66 split into low and high
        // halves.
        let mut config_space = VirtioVsockConfig {
            guest_cid_low: ReadOnly::new(66),
            guest_cid_high: ReadOnly::new(0),
        };
        // Queue state shared between the fake transport and the device thread below; one entry
        // per virtqueue.
        let state = Arc::new(Mutex::new(State {
            queues: vec![
                QueueStatus::default(),
                QueueStatus::default(),
                QueueStatus::default(),
            ],
            ..Default::default()
        }));
        let transport = FakeTransport {
            device_type: DeviceType::Socket,
            max_queue_size: 32,
            device_features: 0,
            config_space: NonNull::from(&mut config_space),
            state: state.clone(),
        };
        let mut socket = VsockConnectionManager::new(
            VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(transport).unwrap(),
        );

        // Only `guest_port` accepts incoming connections; `wrong_guest_port` does not.
        socket.listen(guest_port);

        // Start a thread to simulate the device.
        let handle = thread::spawn(move || {
            // Send a connection request for a port the guest isn't listening on.
            println!("Host sending connection request to wrong port");
            state.lock().unwrap().write_to_queue::<QUEUE_SIZE>(
                RX_QUEUE_IDX,
                VirtioVsockHdr {
                    op: VirtioVsockOp::Request.into(),
                    src_cid: host_cid.into(),
                    dst_cid: guest_cid.into(),
                    src_port: host_port.into(),
                    dst_port: wrong_guest_port.into(),
                    len: 0.into(),
                    socket_type: SocketType::Stream.into(),
                    flags: 0.into(),
                    buf_alloc: 50.into(),
                    fwd_cnt: 0.into(),
                }
                .as_bytes(),
            );

            // Expect a rejection.
            println!("Host waiting for rejection");
            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
            assert_eq!(
                VirtioVsockHdr::read_from(
                    state
                        .lock()
                        .unwrap()
                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
                        .as_slice()
                )
                .unwrap(),
                VirtioVsockHdr {
                    op: VirtioVsockOp::Rst.into(),
                    src_cid: guest_cid.into(),
                    dst_cid: host_cid.into(),
                    src_port: wrong_guest_port.into(),
                    dst_port: host_port.into(),
                    len: 0.into(),
                    socket_type: SocketType::Stream.into(),
                    flags: 0.into(),
                    buf_alloc: 1024.into(),
                    fwd_cnt: 0.into(),
                }
            );

            // Send a connection request for a port the guest is listening on.
            println!("Host sending connection request to right port");
            state.lock().unwrap().write_to_queue::<QUEUE_SIZE>(
                RX_QUEUE_IDX,
                VirtioVsockHdr {
                    op: VirtioVsockOp::Request.into(),
                    src_cid: host_cid.into(),
                    dst_cid: guest_cid.into(),
                    src_port: host_port.into(),
                    dst_port: guest_port.into(),
                    len: 0.into(),
                    socket_type: SocketType::Stream.into(),
                    flags: 0.into(),
                    buf_alloc: 50.into(),
                    fwd_cnt: 0.into(),
                }
                .as_bytes(),
            );

            // Expect a response.
            println!("Host waiting for response");
            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
            assert_eq!(
                VirtioVsockHdr::read_from(
                    state
                        .lock()
                        .unwrap()
                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
                        .as_slice()
                )
                .unwrap(),
                VirtioVsockHdr {
                    op: VirtioVsockOp::Response.into(),
                    src_cid: guest_cid.into(),
                    dst_cid: host_cid.into(),
                    src_port: guest_port.into(),
                    dst_port: host_port.into(),
                    len: 0.into(),
                    socket_type: SocketType::Stream.into(),
                    flags: 0.into(),
                    buf_alloc: 1024.into(),
                    fwd_cnt: 0.into(),
                }
            );

            println!("Host finished");
        });

        // Expect an incoming connection. The rejected request to the wrong port is handled inside
        // `poll` and never surfaces as an event, so the first event is the accepted one.
        println!("Guest expecting incoming connection.");
        assert_eq!(
            socket.wait_for_event().unwrap(),
            VsockEvent {
                source: host_address,
                destination: VsockAddr {
                    cid: guest_cid,
                    port: guest_port,
                },
                event_type: VsockEventType::ConnectionRequest,
                buffer_status: VsockBufferStatus {
                    buffer_allocation: 50,
                    forward_count: 0,
                },
            }
        );

        handle.join().unwrap();
    }
}
802