// Copyright 2023, The Android Open Source Project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. //! Supports for the communication between rialto and host. use crate::error::Result; use ciborium_io::{Read, Write}; use core::hint::spin_loop; use core::mem; use core::result; use log::info; use service_vm_comm::{Response, ServiceVmRequest}; use tinyvec::ArrayVec; use virtio_drivers::{ self, device::socket::{ SocketError, VirtIOSocket, VsockAddr, VsockConnectionManager, VsockEventType, }, transport::Transport, Hal, }; const WRITE_BUF_CAPACITY: usize = 512; pub struct VsockStream { connection_manager: VsockConnectionManager, /// Peer address. The same port is used on rialto and peer for convenience. peer_addr: VsockAddr, write_buf: ArrayVec<[u8; WRITE_BUF_CAPACITY]>, } impl VsockStream { pub fn new( socket_device_driver: VirtIOSocket, peer_addr: VsockAddr, ) -> virtio_drivers::Result { let mut vsock_stream = Self { connection_manager: VsockConnectionManager::new(socket_device_driver), peer_addr, write_buf: ArrayVec::default(), }; vsock_stream.connect()?; Ok(vsock_stream) } fn connect(&mut self) -> virtio_drivers::Result { self.connection_manager.connect(self.peer_addr, self.peer_addr.port)?; self.wait_for_connect()?; info!("Connected to the peer {:?}", self.peer_addr); Ok(()) } fn wait_for_connect(&mut self) -> virtio_drivers::Result { loop { if let Some(event) = self.poll_event_from_peer()? { match event { VsockEventType::Connected => return Ok(()), VsockEventType::Disconnected { .. } => { return Err(SocketError::ConnectionFailed.into()) } // We shouldn't receive the following event before the connection is // established. VsockEventType::ConnectionRequest | VsockEventType::Received { .. } => { return Err(SocketError::InvalidOperation.into()) } // We can receive credit requests and updates at any time. // This can be ignored as the connection manager handles them in poll(). VsockEventType::CreditRequest | VsockEventType::CreditUpdate => {} } } else { spin_loop(); } } } pub fn read_request(&mut self) -> Result { Ok(ciborium::from_reader(self)?) } pub fn write_response(&mut self, response: &Response) -> Result<()> { Ok(ciborium::into_writer(response, self)?) } /// Shuts down the data channel. pub fn shutdown(&mut self) -> virtio_drivers::Result { self.connection_manager.force_close(self.peer_addr, self.peer_addr.port)?; info!("Connection shutdown."); Ok(()) } fn recv(&mut self, buffer: &mut [u8]) -> virtio_drivers::Result { let bytes_read = self.connection_manager.recv(self.peer_addr, self.peer_addr.port, buffer)?; let buffer_available_bytes = self .connection_manager .recv_buffer_available_bytes(self.peer_addr, self.peer_addr.port)?; if buffer_available_bytes == 0 && bytes_read > 0 { self.connection_manager.update_credit(self.peer_addr, self.peer_addr.port)?; } Ok(bytes_read) } fn wait_for_send(&mut self, buffer: &[u8]) -> virtio_drivers::Result { const INSUFFICIENT_BUFFER_SPACE_ERROR: virtio_drivers::Error = virtio_drivers::Error::SocketDeviceError(SocketError::InsufficientBufferSpaceInPeer); loop { match self.connection_manager.send(self.peer_addr, self.peer_addr.port, buffer) { Ok(_) => return Ok(()), Err(INSUFFICIENT_BUFFER_SPACE_ERROR) => { self.poll()?; } Err(e) => return Err(e), } } } fn wait_for_recv(&mut self) -> virtio_drivers::Result { loop { match self.poll()? { Some(VsockEventType::Received { .. }) => return Ok(()), _ => spin_loop(), } } } /// Polls the rx queue after the connection is established with the peer, this function /// rejects some invalid events. The valid events are handled inside the connection /// manager. fn poll(&mut self) -> virtio_drivers::Result> { if let Some(event) = self.poll_event_from_peer()? { match event { VsockEventType::Disconnected { .. } => Err(SocketError::ConnectionFailed.into()), VsockEventType::Connected | VsockEventType::ConnectionRequest => { Err(SocketError::InvalidOperation.into()) } // When there is a received event, the received data is buffered in the // connection manager's internal receive buffer, so we don't need to do // anything here. // The credit request and updates also handled inside the connection // manager. VsockEventType::Received { .. } | VsockEventType::CreditRequest | VsockEventType::CreditUpdate => Ok(Some(event)), } } else { Ok(None) } } fn poll_event_from_peer(&mut self) -> virtio_drivers::Result> { Ok(self.connection_manager.poll()?.map(|event| { assert_eq!(event.source, self.peer_addr); assert_eq!(event.destination.port, self.peer_addr.port); event.event_type })) } } impl Read for VsockStream { type Error = virtio_drivers::Error; fn read_exact(&mut self, data: &mut [u8]) -> result::Result<(), Self::Error> { let mut start = 0; while start < data.len() { let len = self.recv(&mut data[start..])?; let len = if len == 0 { self.wait_for_recv()?; self.recv(&mut data[start..])? } else { len }; start += len; } Ok(()) } } impl Write for VsockStream { type Error = virtio_drivers::Error; fn write_all(&mut self, data: &[u8]) -> result::Result<(), Self::Error> { if data.len() >= self.write_buf.capacity() - self.write_buf.len() { self.flush()?; if data.len() >= self.write_buf.capacity() { self.wait_for_send(data)?; return Ok(()); } } self.write_buf.extend_from_slice(data); Ok(()) } fn flush(&mut self) -> result::Result<(), Self::Error> { if !self.write_buf.is_empty() { // We need to take the memory from self.write_buf to a temporary // buffer to avoid borrowing `*self` as mutable and immutable on // the same time in `self.wait_for_send(&self.write_buf)`. let buffer = mem::take(&mut self.write_buf); self.wait_for_send(&buffer)?; } Ok(()) } }