// Copyright 2019 The ChromiumOS Authors // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. use std::cmp; use std::io; use std::io::Write; use std::iter::FromIterator; use std::marker::PhantomData; use std::mem::size_of; use std::mem::MaybeUninit; use std::ptr::copy_nonoverlapping; use std::sync::Arc; use anyhow::Context; use base::FileReadWriteAtVolatile; use base::FileReadWriteVolatile; use base::VolatileSlice; use cros_async::MemRegion; use cros_async::MemRegionIter; use data_model::Le16; use data_model::Le32; use data_model::Le64; use disk::AsyncDisk; use smallvec::SmallVec; use vm_memory::GuestAddress; use vm_memory::GuestMemory; use zerocopy::AsBytes; use zerocopy::FromBytes; use zerocopy::FromZeroes; use super::DescriptorChain; use crate::virtio::SplitDescriptorChain; struct DescriptorChainRegions { regions: SmallVec<[MemRegion; 2]>, // Index of the current region in `regions`. current_region_index: usize, // Number of bytes consumed in the current region. current_region_offset: usize, // Total bytes consumed in the entire descriptor chain. bytes_consumed: usize, } impl DescriptorChainRegions { fn new(regions: SmallVec<[MemRegion; 2]>) -> Self { DescriptorChainRegions { regions, current_region_index: 0, current_region_offset: 0, bytes_consumed: 0, } } fn available_bytes(&self) -> usize { // This is guaranteed not to overflow because the total length of the chain is checked // during all creations of `DescriptorChain` (see `DescriptorChain::new()`). self.get_remaining_regions() .fold(0usize, |count, region| count + region.len) } fn bytes_consumed(&self) -> usize { self.bytes_consumed } /// Returns all the remaining buffers in the `DescriptorChain`. Calling this function does not /// consume any bytes from the `DescriptorChain`. Instead callers should use the `consume` /// method to advance the `DescriptorChain`. Multiple calls to `get` with no intervening calls /// to `consume` will return the same data. 
fn get_remaining_regions(&self) -> MemRegionIter { MemRegionIter::new(&self.regions[self.current_region_index..]) .skip_bytes(self.current_region_offset) } /// Like `get_remaining_regions` but guarantees that the combined length of all the returned /// iovecs is not greater than `count`. The combined length of the returned iovecs may be less /// than `count` but will always be greater than 0 as long as there is still space left in the /// `DescriptorChain`. fn get_remaining_regions_with_count(&self, count: usize) -> MemRegionIter { MemRegionIter::new(&self.regions[self.current_region_index..]) .skip_bytes(self.current_region_offset) .take_bytes(count) } /// Returns all the remaining buffers in the `DescriptorChain` as `VolatileSlice`s of the given /// `GuestMemory`. Calling this function does not consume any bytes from the `DescriptorChain`. /// Instead callers should use the `consume` method to advance the `DescriptorChain`. Multiple /// calls to `get` with no intervening calls to `consume` will return the same data. fn get_remaining<'mem>(&self, mem: &'mem GuestMemory) -> SmallVec<[VolatileSlice<'mem>; 16]> { self.get_remaining_regions() .filter_map(|region| { mem.get_slice_at_addr(GuestAddress(region.offset), region.len) .ok() }) .collect() } /// Like 'get_remaining_regions_with_count' except convert the offsets to volatile slices in /// the 'GuestMemory' given by 'mem'. fn get_remaining_with_count<'mem>( &self, mem: &'mem GuestMemory, count: usize, ) -> SmallVec<[VolatileSlice<'mem>; 16]> { self.get_remaining_regions_with_count(count) .filter_map(|region| { mem.get_slice_at_addr(GuestAddress(region.offset), region.len) .ok() }) .collect() } /// Consumes `count` bytes from the `DescriptorChain`. If `count` is larger than /// `self.available_bytes()` then all remaining bytes in the `DescriptorChain` will be consumed. 
fn consume(&mut self, mut count: usize) { while let Some(region) = self.regions.get(self.current_region_index) { let region_remaining = region.len - self.current_region_offset; if count < region_remaining { // The remaining count to consume is less than the remaining un-consumed length of // the current region. Adjust the region offset without advancing to the next region // and stop. self.current_region_offset += count; self.bytes_consumed += count; return; } // The current region has been exhausted. Advance to the next region. self.current_region_index += 1; self.current_region_offset = 0; self.bytes_consumed += region_remaining; count -= region_remaining; } } fn split_at(&mut self, offset: usize) -> DescriptorChainRegions { let mut other = DescriptorChainRegions { regions: self.regions.clone(), current_region_index: self.current_region_index, current_region_offset: self.current_region_offset, bytes_consumed: self.bytes_consumed, }; other.consume(offset); other.bytes_consumed = 0; let mut rem = offset; let mut end = self.current_region_index; for region in &mut self.regions[self.current_region_index..] { if rem <= region.len { region.len = rem; break; } end += 1; rem -= region.len; } self.regions.truncate(end + 1); other } } /// Provides high-level interface over the sequence of memory regions /// defined by readable descriptors in the descriptor chain. /// /// Note that virtio spec requires driver to place any device-writable /// descriptors after any device-readable descriptors (2.6.4.2 in Virtio Spec v1.1). /// Reader will skip iterating over descriptor chain when first writable /// descriptor is encountered. pub struct Reader { mem: GuestMemory, regions: DescriptorChainRegions, } /// An iterator over `FromBytes` objects on readable descriptors in the descriptor chain. 
pub struct ReaderIterator<'a, T: FromBytes> { reader: &'a mut Reader, phantom: PhantomData, } impl<'a, T: FromBytes> Iterator for ReaderIterator<'a, T> { type Item = io::Result; fn next(&mut self) -> Option> { if self.reader.available_bytes() == 0 { None } else { Some(self.reader.read_obj()) } } } impl Reader { /// Construct a new Reader wrapper over `readable_regions`. pub fn new_from_regions( mem: &GuestMemory, readable_regions: SmallVec<[MemRegion; 2]>, ) -> Reader { Reader { mem: mem.clone(), regions: DescriptorChainRegions::new(readable_regions), } } /// Reads an object from the descriptor chain buffer without consuming it. pub fn peek_obj(&self) -> io::Result { let mut obj = MaybeUninit::uninit(); // SAFETY: We pass a valid pointer and size of `obj`. let copied = unsafe { copy_regions_to_mut_ptr( &self.mem, self.get_remaining_regions(), obj.as_mut_ptr() as *mut u8, size_of::(), )? }; if copied != size_of::() { return Err(io::Error::from(io::ErrorKind::UnexpectedEof)); } // SAFETY: `FromBytes` guarantees any set of initialized bytes is a valid value for `T`, and // we initialized all bytes in `obj` in the copy above. Ok(unsafe { obj.assume_init() }) } /// Reads and consumes an object from the descriptor chain buffer. pub fn read_obj(&mut self) -> io::Result { let obj = self.peek_obj::()?; self.consume(size_of::()); Ok(obj) } /// Reads objects by consuming all the remaining data in the descriptor chain buffer and returns /// them as a collection. Returns an error if the size of the remaining data is indivisible by /// the size of an object of type `T`. pub fn collect>, T: FromBytes>(&mut self) -> C { self.iter().collect() } /// Creates an iterator for sequentially reading `FromBytes` objects from the `Reader`. /// Unlike `collect`, this doesn't consume all the remaining data in the `Reader` and /// doesn't require the objects to be stored in a separate collection. 
pub fn iter(&mut self) -> ReaderIterator { ReaderIterator { reader: self, phantom: PhantomData, } } /// Reads data into a volatile slice up to the minimum of the slice's length or the number of /// bytes remaining. Returns the number of bytes read. pub fn read_to_volatile_slice(&mut self, slice: VolatileSlice) -> usize { let mut read = 0usize; let mut dst = slice; for src in self.get_remaining() { src.copy_to_volatile_slice(dst); let copied = std::cmp::min(src.size(), dst.size()); read += copied; dst = match dst.offset(copied) { Ok(v) => v, Err(_) => break, // The slice is fully consumed }; } self.regions.consume(read); read } /// Reads data from the descriptor chain buffer and passes the `VolatileSlice`s to the callback /// `cb`. pub fn read_to_cb usize>( &mut self, cb: C, count: usize, ) -> usize { let iovs = self.regions.get_remaining_with_count(&self.mem, count); let written = cb(&iovs[..]); self.regions.consume(written); written } /// Reads data from the descriptor chain buffer into a writable object. /// Returns the number of bytes read from the descriptor chain buffer. /// The number of bytes read can be less than `count` if there isn't /// enough data in the descriptor chain buffer. pub fn read_to( &mut self, mut dst: F, count: usize, ) -> io::Result { let iovs = self.regions.get_remaining_with_count(&self.mem, count); let written = dst.write_vectored_volatile(&iovs[..])?; self.regions.consume(written); Ok(written) } /// Reads data from the descriptor chain buffer into a File at offset `off`. /// Returns the number of bytes read from the descriptor chain buffer. /// The number of bytes read can be less than `count` if there isn't /// enough data in the descriptor chain buffer. 
pub fn read_to_at( &mut self, dst: &F, count: usize, off: u64, ) -> io::Result { let iovs = self.regions.get_remaining_with_count(&self.mem, count); let written = dst.write_vectored_at_volatile(&iovs[..], off)?; self.regions.consume(written); Ok(written) } /// Reads data from the descriptor chain similar to 'read_to' except reading 'count' or /// returning an error if 'count' bytes can't be read. pub fn read_exact_to( &mut self, mut dst: F, mut count: usize, ) -> io::Result<()> { while count > 0 { match self.read_to(&mut dst, count) { Ok(0) => { return Err(io::Error::new( io::ErrorKind::UnexpectedEof, "failed to fill whole buffer", )) } Ok(n) => count -= n, Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {} Err(e) => return Err(e), } } Ok(()) } /// Reads data from the descriptor chain similar to 'read_to_at' except reading 'count' or /// returning an error if 'count' bytes can't be read. pub fn read_exact_to_at( &mut self, dst: &F, mut count: usize, mut off: u64, ) -> io::Result<()> { while count > 0 { match self.read_to_at(dst, count, off) { Ok(0) => { return Err(io::Error::new( io::ErrorKind::UnexpectedEof, "failed to fill whole buffer", )) } Ok(n) => { count -= n; off += n as u64; } Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {} Err(e) => return Err(e), } } Ok(()) } /// Reads data from the descriptor chain buffer into an `AsyncDisk` at offset `off`. /// Returns the number of bytes read from the descriptor chain buffer. /// The number of bytes read can be less than `count` if there isn't /// enough data in the descriptor chain buffer. pub async fn read_to_at_fut( &mut self, dst: &F, count: usize, off: u64, ) -> disk::Result { let written = dst .write_from_mem( off, Arc::new(self.mem.clone()), self.regions.get_remaining_regions_with_count(count), ) .await?; self.regions.consume(written); Ok(written) } /// Reads exactly `count` bytes from the chain to the disk asynchronously or returns an error if /// not enough data can be read. 
pub async fn read_exact_to_at_fut( &mut self, dst: &F, mut count: usize, mut off: u64, ) -> disk::Result<()> { while count > 0 { let nread = self.read_to_at_fut(dst, count, off).await?; if nread == 0 { return Err(disk::Error::ReadingData(io::Error::new( io::ErrorKind::UnexpectedEof, "failed to write whole buffer", ))); } count -= nread; off += nread as u64; } Ok(()) } /// Returns number of bytes available for reading. May return an error if the combined /// lengths of all the buffers in the DescriptorChain would cause an integer overflow. pub fn available_bytes(&self) -> usize { self.regions.available_bytes() } /// Returns number of bytes already read from the descriptor chain buffer. pub fn bytes_read(&self) -> usize { self.regions.bytes_consumed() } pub fn get_remaining_regions(&self) -> MemRegionIter { self.regions.get_remaining_regions() } /// Returns a `&[VolatileSlice]` that represents all the remaining data in this `Reader`. /// Calling this method does not actually consume any data from the `Reader` and callers should /// call `consume` to advance the `Reader`. pub fn get_remaining(&self) -> SmallVec<[VolatileSlice; 16]> { self.regions.get_remaining(&self.mem) } /// Consumes `amt` bytes from the underlying descriptor chain. If `amt` is larger than the /// remaining data left in this `Reader`, then all remaining data will be consumed. pub fn consume(&mut self, amt: usize) { self.regions.consume(amt) } /// Splits this `Reader` into two at the given offset in the `DescriptorChain` buffer. After the /// split, `self` will be able to read up to `offset` bytes while the returned `Reader` can read /// up to `available_bytes() - offset` bytes. If `offset > self.available_bytes()`, then the /// returned `Reader` will not be able to read any bytes. pub fn split_at(&mut self, offset: usize) -> Reader { Reader { mem: self.mem.clone(), regions: self.regions.split_at(offset), } } } /// Copy up to `size` bytes from `src` into `dst`. 
/// /// Returns the total number of bytes copied. /// /// # Safety /// /// The caller must ensure that it is safe to write `size` bytes of data into `dst`. /// /// After the function returns, it is only safe to assume that the number of bytes indicated by the /// return value (which may be less than the requested `size`) have been initialized. Bytes beyond /// that point are not initialized by this function. unsafe fn copy_regions_to_mut_ptr( mem: &GuestMemory, src: MemRegionIter, dst: *mut u8, size: usize, ) -> io::Result { let mut copied = 0; for src_region in src { if copied >= size { break; } let remaining = size - copied; let count = cmp::min(remaining, src_region.len); let vslice = mem .get_slice_at_addr(GuestAddress(src_region.offset), count) .map_err(|_e| io::Error::from(io::ErrorKind::InvalidData))?; // SAFETY: `get_slice_at_addr()` verified that the region points to valid memory, and // the `count` calculation ensures we will write at most `size` bytes into `dst`. unsafe { copy_nonoverlapping(vslice.as_ptr(), dst.add(copied), count); } copied += count; } Ok(copied) } impl io::Read for Reader { fn read(&mut self, buf: &mut [u8]) -> io::Result { // SAFETY: We pass a valid pointer and size combination derived from `buf`. let total = unsafe { copy_regions_to_mut_ptr( &self.mem, self.regions.get_remaining_regions(), buf.as_mut_ptr(), buf.len(), )? }; self.regions.consume(total); Ok(total) } } /// Provides high-level interface over the sequence of memory regions /// defined by writable descriptors in the descriptor chain. /// /// Note that virtio spec requires driver to place any device-writable /// descriptors after any device-readable descriptors (2.6.4.2 in Virtio Spec v1.1). /// Writer will start iterating the descriptors from the first writable one and will /// assume that all following descriptors are writable. pub struct Writer { mem: GuestMemory, regions: DescriptorChainRegions, } impl Writer { /// Construct a new Writer wrapper over `writable_regions`. 
pub fn new_from_regions( mem: &GuestMemory, writable_regions: SmallVec<[MemRegion; 2]>, ) -> Writer { Writer { mem: mem.clone(), regions: DescriptorChainRegions::new(writable_regions), } } /// Writes an object to the descriptor chain buffer. pub fn write_obj(&mut self, val: T) -> io::Result<()> { self.write_all(val.as_bytes()) } /// Writes all objects produced by `iter` into the descriptor chain buffer. Unlike `consume`, /// this doesn't require the values to be stored in an intermediate collection first. It also /// allows callers to choose which elements in a collection to write, for example by using the /// `filter` or `take` methods of the `Iterator` trait. pub fn write_iter>(&mut self, mut iter: I) -> io::Result<()> { iter.try_for_each(|v| self.write_obj(v)) } /// Writes a collection of objects into the descriptor chain buffer. pub fn consume>(&mut self, vals: C) -> io::Result<()> { self.write_iter(vals.into_iter()) } /// Returns number of bytes available for writing. May return an error if the combined /// lengths of all the buffers in the DescriptorChain would cause an overflow. pub fn available_bytes(&self) -> usize { self.regions.available_bytes() } /// Reads data into a volatile slice up to the minimum of the slice's length or the number of /// bytes remaining. Returns the number of bytes read. pub fn write_from_volatile_slice(&mut self, slice: VolatileSlice) -> usize { let mut written = 0usize; let mut src = slice; for dst in self.get_remaining() { src.copy_to_volatile_slice(dst); let copied = std::cmp::min(src.size(), dst.size()); written += copied; src = match src.offset(copied) { Ok(v) => v, Err(_) => break, // The slice is fully consumed }; } self.regions.consume(written); written } /// Writes data to the descriptor chain buffer from a readable object. /// Returns the number of bytes written to the descriptor chain buffer. /// The number of bytes written can be less than `count` if /// there isn't enough data in the descriptor chain buffer. 
pub fn write_from( &mut self, mut src: F, count: usize, ) -> io::Result { let iovs = self.regions.get_remaining_with_count(&self.mem, count); let read = src.read_vectored_volatile(&iovs[..])?; self.regions.consume(read); Ok(read) } /// Writes data to the descriptor chain buffer from a File at offset `off`. /// Returns the number of bytes written to the descriptor chain buffer. /// The number of bytes written can be less than `count` if /// there isn't enough data in the descriptor chain buffer. pub fn write_from_at( &mut self, src: &F, count: usize, off: u64, ) -> io::Result { let iovs = self.regions.get_remaining_with_count(&self.mem, count); let read = src.read_vectored_at_volatile(&iovs[..], off)?; self.regions.consume(read); Ok(read) } pub fn write_all_from( &mut self, mut src: F, mut count: usize, ) -> io::Result<()> { while count > 0 { match self.write_from(&mut src, count) { Ok(0) => { return Err(io::Error::new( io::ErrorKind::WriteZero, "failed to write whole buffer", )) } Ok(n) => count -= n, Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {} Err(e) => return Err(e), } } Ok(()) } pub fn write_all_from_at( &mut self, src: &F, mut count: usize, mut off: u64, ) -> io::Result<()> { while count > 0 { match self.write_from_at(src, count, off) { Ok(0) => { return Err(io::Error::new( io::ErrorKind::WriteZero, "failed to write whole buffer", )) } Ok(n) => { count -= n; off += n as u64; } Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {} Err(e) => return Err(e), } } Ok(()) } /// Writes data to the descriptor chain buffer from an `AsyncDisk` at offset `off`. /// Returns the number of bytes written to the descriptor chain buffer. /// The number of bytes written can be less than `count` if /// there isn't enough data in the descriptor chain buffer. 
pub async fn write_from_at_fut( &mut self, src: &F, count: usize, off: u64, ) -> disk::Result { let read = src .read_to_mem( off, Arc::new(self.mem.clone()), self.regions.get_remaining_regions_with_count(count), ) .await?; self.regions.consume(read); Ok(read) } pub async fn write_all_from_at_fut( &mut self, src: &F, mut count: usize, mut off: u64, ) -> disk::Result<()> { while count > 0 { let nwritten = self.write_from_at_fut(src, count, off).await?; if nwritten == 0 { return Err(disk::Error::WritingData(io::Error::new( io::ErrorKind::UnexpectedEof, "failed to write whole buffer", ))); } count -= nwritten; off += nwritten as u64; } Ok(()) } /// Returns number of bytes already written to the descriptor chain buffer. pub fn bytes_written(&self) -> usize { self.regions.bytes_consumed() } pub fn get_remaining_regions(&self) -> MemRegionIter { self.regions.get_remaining_regions() } /// Returns a `&[VolatileSlice]` that represents all the remaining data in this `Writer`. /// Calling this method does not actually advance the current position of the `Writer` in the /// buffer and callers should call `consume_bytes` to advance the `Writer`. Not calling /// `consume_bytes` with the amount of data copied into the returned `VolatileSlice`s will /// result in that that data being overwritten the next time data is written into the `Writer`. pub fn get_remaining(&self) -> SmallVec<[VolatileSlice; 16]> { self.regions.get_remaining(&self.mem) } /// Consumes `amt` bytes from the underlying descriptor chain. If `amt` is larger than the /// remaining data left in this `Reader`, then all remaining data will be consumed. pub fn consume_bytes(&mut self, amt: usize) { self.regions.consume(amt) } /// Splits this `Writer` into two at the given offset in the `DescriptorChain` buffer. After the /// split, `self` will be able to write up to `offset` bytes while the returned `Writer` can /// write up to `available_bytes() - offset` bytes. 
If `offset > self.available_bytes()`, then /// the returned `Writer` will not be able to write any bytes. pub fn split_at(&mut self, offset: usize) -> Writer { Writer { mem: self.mem.clone(), regions: self.regions.split_at(offset), } } } impl io::Write for Writer { fn write(&mut self, buf: &[u8]) -> io::Result { let mut rem = buf; let mut total = 0; for b in self.regions.get_remaining(&self.mem) { if rem.is_empty() { break; } let count = cmp::min(rem.len(), b.size()); // SAFETY: // Safe because we have already verified that `vs` points to valid memory. unsafe { copy_nonoverlapping(rem.as_ptr(), b.as_mut_ptr(), count); } rem = &rem[count..]; total += count; } self.regions.consume(total); Ok(total) } fn flush(&mut self) -> io::Result<()> { // Nothing to flush since the writes go straight into the buffer. Ok(()) } } const VIRTQ_DESC_F_NEXT: u16 = 0x1; const VIRTQ_DESC_F_WRITE: u16 = 0x2; #[derive(Copy, Clone, PartialEq, Eq)] pub enum DescriptorType { Readable, Writable, } #[derive(Copy, Clone, Debug, FromZeroes, FromBytes, AsBytes)] #[repr(C)] struct virtq_desc { addr: Le64, len: Le32, flags: Le16, next: Le16, } /// Test utility function to create a descriptor chain in guest memory. 
pub fn create_descriptor_chain( memory: &GuestMemory, descriptor_array_addr: GuestAddress, mut buffers_start_addr: GuestAddress, descriptors: Vec<(DescriptorType, u32)>, spaces_between_regions: u32, ) -> anyhow::Result { let descriptors_len = descriptors.len(); for (index, (type_, size)) in descriptors.into_iter().enumerate() { let mut flags = 0; if let DescriptorType::Writable = type_ { flags |= VIRTQ_DESC_F_WRITE; } if index + 1 < descriptors_len { flags |= VIRTQ_DESC_F_NEXT; } let index = index as u16; let desc = virtq_desc { addr: buffers_start_addr.offset().into(), len: size.into(), flags: flags.into(), next: (index + 1).into(), }; let offset = size + spaces_between_regions; buffers_start_addr = buffers_start_addr .checked_add(offset as u64) .context("Invalid buffers_start_addr)")?; let _ = memory.write_obj_at_addr( desc, descriptor_array_addr .checked_add(index as u64 * std::mem::size_of::() as u64) .context("Invalid descriptor_array_addr")?, ); } let chain = SplitDescriptorChain::new(memory, descriptor_array_addr, 0x100, 0); DescriptorChain::new(chain, memory, 0) } #[cfg(test)] mod tests { use std::fs::File; use std::io::Read; use cros_async::Executor; use tempfile::tempfile; use tempfile::NamedTempFile; use super::*; #[test] fn reader_test_simple_chain() { use DescriptorType::*; let memory_start_addr = GuestAddress(0x0); let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap(); let mut chain = create_descriptor_chain( &memory, GuestAddress(0x0), GuestAddress(0x100), vec![ (Readable, 8), (Readable, 16), (Readable, 18), (Readable, 64), ], 0, ) .expect("create_descriptor_chain failed"); let reader = &mut chain.reader; assert_eq!(reader.available_bytes(), 106); assert_eq!(reader.bytes_read(), 0); let mut buffer = [0u8; 64]; reader .read_exact(&mut buffer) .expect("read_exact should not fail here"); assert_eq!(reader.available_bytes(), 42); assert_eq!(reader.bytes_read(), 64); match reader.read(&mut buffer) { Err(_) => panic!("read should not fail 
here"), Ok(length) => assert_eq!(length, 42), } assert_eq!(reader.available_bytes(), 0); assert_eq!(reader.bytes_read(), 106); } #[test] fn writer_test_simple_chain() { use DescriptorType::*; let memory_start_addr = GuestAddress(0x0); let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap(); let mut chain = create_descriptor_chain( &memory, GuestAddress(0x0), GuestAddress(0x100), vec![ (Writable, 8), (Writable, 16), (Writable, 18), (Writable, 64), ], 0, ) .expect("create_descriptor_chain failed"); let writer = &mut chain.writer; assert_eq!(writer.available_bytes(), 106); assert_eq!(writer.bytes_written(), 0); let buffer = [0; 64]; writer .write_all(&buffer) .expect("write_all should not fail here"); assert_eq!(writer.available_bytes(), 42); assert_eq!(writer.bytes_written(), 64); match writer.write(&buffer) { Err(_) => panic!("write should not fail here"), Ok(length) => assert_eq!(length, 42), } assert_eq!(writer.available_bytes(), 0); assert_eq!(writer.bytes_written(), 106); } #[test] fn reader_test_incompatible_chain() { use DescriptorType::*; let memory_start_addr = GuestAddress(0x0); let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap(); let mut chain = create_descriptor_chain( &memory, GuestAddress(0x0), GuestAddress(0x100), vec![(Writable, 8)], 0, ) .expect("create_descriptor_chain failed"); let reader = &mut chain.reader; assert_eq!(reader.available_bytes(), 0); assert_eq!(reader.bytes_read(), 0); assert!(reader.read_obj::().is_err()); assert_eq!(reader.available_bytes(), 0); assert_eq!(reader.bytes_read(), 0); } #[test] fn writer_test_incompatible_chain() { use DescriptorType::*; let memory_start_addr = GuestAddress(0x0); let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap(); let mut chain = create_descriptor_chain( &memory, GuestAddress(0x0), GuestAddress(0x100), vec![(Readable, 8)], 0, ) .expect("create_descriptor_chain failed"); let writer = &mut chain.writer; assert_eq!(writer.available_bytes(), 0); 
assert_eq!(writer.bytes_written(), 0);
        assert!(writer.write_obj(0u8).is_err());
        assert_eq!(writer.available_bytes(), 0);
        assert_eq!(writer.bytes_written(), 0);
    }

    /// Reading the chain into a sink that rejects writes must fail without consuming any bytes
    /// from the chain.
    #[test]
    fn reader_failing_io() {
        use DescriptorType::*;

        let memory_start_addr = GuestAddress(0x0);
        let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();

        let mut chain = create_descriptor_chain(
            &memory,
            GuestAddress(0x0),
            GuestAddress(0x100),
            vec![(Readable, 256), (Readable, 256)],
            0,
        )
        .expect("create_descriptor_chain failed");

        let reader = &mut chain.reader;

        // Open a file in read-only mode so that writes to it trigger an I/O error.
        let device_file = if cfg!(windows) { "NUL" } else { "/dev/zero" };
        let mut ro_file = File::open(device_file).expect("failed to open device file");

        reader
            .read_exact_to(&mut ro_file, 512)
            .expect_err("successfully read more bytes than SharedMemory size");

        // The write above should have failed entirely, so no bytes were consumed from the chain.
        assert_eq!(reader.available_bytes(), 512);
        assert_eq!(reader.bytes_read(), 0);
    }

    /// Filling the chain from a file shorter than the chain must report an error, while the
    /// bytes that were transferred before the failure stay accounted for.
    #[test]
    fn writer_failing_io() {
        use DescriptorType::*;

        let memory_start_addr = GuestAddress(0x0);
        let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();

        let mut chain = create_descriptor_chain(
            &memory,
            GuestAddress(0x0),
            GuestAddress(0x100),
            vec![(Writable, 256), (Writable, 256)],
            0,
        )
        .expect("create_descriptor_chain failed");

        let writer = &mut chain.writer;

        let mut file = tempfile().unwrap();

        // Only 384 of the 512 requested bytes can come from the file.
        file.set_len(384).unwrap();

        writer
            .write_all_from(&mut file, 512)
            .expect_err("successfully wrote more bytes than in SharedMemory");

        assert_eq!(writer.available_bytes(), 128);
        assert_eq!(writer.bytes_written(), 384);
    }

    /// A reader and a writer over the same chain see only the readable and the writable
    /// descriptors respectively, and track their progress independently.
    #[test]
    fn reader_writer_shared_chain() {
        use DescriptorType::*;

        let memory_start_addr = GuestAddress(0x0);
        let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();

        let mut chain = create_descriptor_chain(
            &memory,
            GuestAddress(0x0),
            GuestAddress(0x100),
            vec![
                (Readable, 16),
                (Readable, 16),
                (Readable, 96),
                (Writable, 64),
                (Writable, 1),
                (Writable, 3),
            ],
            0,
        )
        .expect("create_descriptor_chain failed");
        let reader = &mut chain.reader;
        let writer = &mut chain.writer;

        assert_eq!(reader.bytes_read(), 0);
        assert_eq!(writer.bytes_written(), 0);

        let mut buffer = Vec::with_capacity(200);

        assert_eq!(
            reader
                .read_to_end(&mut buffer)
                .expect("read should not fail here"),
            128
        );

        // The writable descriptors are only 68 bytes long.
        writer
            .write_all(&buffer[..68])
            .expect("write should not fail here");

        assert_eq!(reader.available_bytes(), 0);
        assert_eq!(reader.bytes_read(), 128);
        assert_eq!(writer.available_bytes(), 0);
        assert_eq!(writer.bytes_written(), 68);
    }

    /// An object scattered one byte per descriptor round-trips through `write_obj`/`read_obj`.
    #[test]
    fn reader_writer_shattered_object() {
        use DescriptorType::*;

        let memory_start_addr = GuestAddress(0x0);
        let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();

        let secret: Le32 = 0x12345678.into();

        // Create a descriptor chain with memory regions that are properly separated
        // (the trailing `123` presumably sets the gap between the one-byte buffers).
        let mut chain_writer = create_descriptor_chain(
            &memory,
            GuestAddress(0x0),
            GuestAddress(0x100),
            vec![(Writable, 1), (Writable, 1), (Writable, 1), (Writable, 1)],
            123,
        )
        .expect("create_descriptor_chain failed");
        let writer = &mut chain_writer.writer;
        writer
            .write_obj(secret)
            .expect("write_obj should not fail here");

        // Now create a new descriptor chain pointing to the same memory and try to read it.
        let mut chain_reader = create_descriptor_chain(
            &memory,
            GuestAddress(0x0),
            GuestAddress(0x100),
            vec![(Readable, 1), (Readable, 1), (Readable, 1), (Readable, 1)],
            123,
        )
        .expect("create_descriptor_chain failed");
        let reader = &mut chain_reader.reader;
        match reader.read_obj::<Le32>() {
            Err(_) => panic!("read_obj should not fail here"),
            Ok(read_secret) => assert_eq!(read_secret, secret),
        }
    }

    /// `read_exact` asking for more than the chain holds reports `UnexpectedEof`.
    #[test]
    fn reader_unexpected_eof() {
        use DescriptorType::*;

        let memory_start_addr = GuestAddress(0x0);
        let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();

        let mut chain = create_descriptor_chain(
            &memory,
            GuestAddress(0x0),
            GuestAddress(0x100),
            vec![(Readable, 256), (Readable, 256)],
            0,
        )
        .expect("create_descriptor_chain failed");
        let reader = &mut chain.reader;

        let mut buf = vec![0; 1024];

        assert_eq!(
            reader
                .read_exact(&mut buf[..])
                .expect_err("read more bytes than available")
                .kind(),
            io::ErrorKind::UnexpectedEof
        );
    }

    /// `split_at` exactly on a descriptor boundary (16 + 16 bytes).
    #[test]
    fn split_border() {
        use DescriptorType::*;

        let memory_start_addr = GuestAddress(0x0);
        let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();

        let mut chain = create_descriptor_chain(
            &memory,
            GuestAddress(0x0),
            GuestAddress(0x100),
            vec![
                (Readable, 16),
                (Readable, 16),
                (Readable, 96),
                (Writable, 64),
                (Writable, 1),
                (Writable, 3),
            ],
            0,
        )
        .expect("create_descriptor_chain failed");
        let reader = &mut chain.reader;

        let other = reader.split_at(32);
        assert_eq!(reader.available_bytes(), 32);
        assert_eq!(other.available_bytes(), 96);
    }

    /// `split_at` in the middle of a descriptor.
    #[test]
    fn split_middle() {
        use DescriptorType::*;

        let memory_start_addr = GuestAddress(0x0);
        let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();

        let mut chain = create_descriptor_chain(
            &memory,
            GuestAddress(0x0),
            GuestAddress(0x100),
            vec![
                (Readable, 16),
                (Readable, 16),
                (Readable, 96),
                (Writable, 64),
                (Writable, 1),
                (Writable, 3),
            ],
            0,
        )
        .expect("create_descriptor_chain failed");
        let reader = &mut chain.reader;

        let other = reader.split_at(24);
        assert_eq!(reader.available_bytes(), 24);
        assert_eq!(other.available_bytes(), 104);
    }

    /// `split_at` at the very end of the readable bytes leaves nothing for the other half.
    #[test]
    fn split_end() {
        use DescriptorType::*;

        let memory_start_addr = GuestAddress(0x0);
        let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();

        let mut chain = create_descriptor_chain(
            &memory,
            GuestAddress(0x0),
            GuestAddress(0x100),
            vec![
                (Readable, 16),
                (Readable, 16),
                (Readable, 96),
                (Writable, 64),
                (Writable, 1),
                (Writable, 3),
            ],
            0,
        )
        .expect("create_descriptor_chain failed");
        let reader = &mut chain.reader;

        let other = reader.split_at(128);
        assert_eq!(reader.available_bytes(), 128);
        assert_eq!(other.available_bytes(), 0);
    }

    /// `split_at(0)` moves every readable byte into the other half.
    #[test]
    fn split_beginning() {
        use DescriptorType::*;

        let memory_start_addr = GuestAddress(0x0);
        let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();

        let mut chain = create_descriptor_chain(
            &memory,
            GuestAddress(0x0),
            GuestAddress(0x100),
            vec![
                (Readable, 16),
                (Readable, 16),
                (Readable, 96),
                (Writable, 64),
                (Writable, 1),
                (Writable, 3),
            ],
            0,
        )
        .expect("create_descriptor_chain failed");
        let reader = &mut chain.reader;

        let other = reader.split_at(0);
        assert_eq!(reader.available_bytes(), 0);
        assert_eq!(other.available_bytes(), 128);
    }

    /// `split_at` past the end of the readable bytes yields an empty other half.
    #[test]
    fn split_outofbounds() {
        use DescriptorType::*;

        let memory_start_addr = GuestAddress(0x0);
        let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();

        let mut chain = create_descriptor_chain(
            &memory,
            GuestAddress(0x0),
            GuestAddress(0x100),
            vec![
                (Readable, 16),
                (Readable, 16),
                (Readable, 96),
                (Writable, 64),
                (Writable, 1),
                (Writable, 3),
            ],
            0,
        )
        .expect("create_descriptor_chain failed");
        let reader = &mut chain.reader;

        let other = reader.split_at(256);
        assert_eq!(
            other.available_bytes(),
            0,
            "Reader returned from out-of-bounds split still has available bytes"
        );
    }

    /// A `read` into a buffer larger than the chain returns only the 48 bytes the chain holds.
    #[test]
    fn read_full() {
        use DescriptorType::*;

        let memory_start_addr = GuestAddress(0x0);
        let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();

        let mut chain = create_descriptor_chain(
            &memory,
            GuestAddress(0x0),
            GuestAddress(0x100),
            vec![(Readable, 16), (Readable, 16), (Readable, 16)],
            0,
        )
        .expect("create_descriptor_chain failed");
        let reader = &mut chain.reader;

        let mut buf = [0u8; 64];
        assert_eq!(
            reader.read(&mut buf[..]).expect("failed to read to buffer"),
            48
        );
    }

    /// A `write` from a buffer larger than the chain stores only the 48 bytes the chain holds.
    #[test]
    fn write_full() {
        use DescriptorType::*;

        let memory_start_addr = GuestAddress(0x0);
        let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();

        let mut chain = create_descriptor_chain(
            &memory,
            GuestAddress(0x0),
            GuestAddress(0x100),
            vec![(Writable, 16), (Writable, 16), (Writable, 16)],
            0,
        )
        .expect("create_descriptor_chain failed");
        let writer = &mut chain.writer;

        let buf = [0xdeu8; 64];
        assert_eq!(
            writer.write(&buf[..]).expect("failed to write from buffer"),
            48
        );
    }

    /// Values stored with `Writer::consume` read back unchanged through `Reader::collect`.
    #[test]
    fn consume_collect() {
        use DescriptorType::*;

        let memory_start_addr = GuestAddress(0x0);
        let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
        let vs: Vec<Le64> = vec![
            0x0101010101010101.into(),
            0x0202020202020202.into(),
            0x0303030303030303.into(),
        ];

        let mut write_chain = create_descriptor_chain(
            &memory,
            GuestAddress(0x0),
            GuestAddress(0x100),
            vec![(Writable, 24)],
            0,
        )
        .expect("create_descriptor_chain failed");
        let writer = &mut write_chain.writer;
        writer
            .consume(vs.clone())
            .expect("failed to consume() a vector");

        let mut read_chain = create_descriptor_chain(
            &memory,
            GuestAddress(0x0),
            GuestAddress(0x100),
            vec![(Readable, 24)],
            0,
        )
        .expect("create_descriptor_chain failed");
        let reader = &mut read_chain.reader;
        let vs_read = reader
            .collect::<io::Result<Vec<Le64>>, _>()
            .expect("failed to collect() values");
        assert_eq!(vs, vs_read);
    }

    /// `get_remaining_regions_with_count` never yields more than `count` bytes, and yields
    /// something while readable bytes remain.
    #[test]
    fn get_remaining_region_with_count() {
        use DescriptorType::*;

        let memory_start_addr = GuestAddress(0x0);
        let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();

        let chain = create_descriptor_chain(
            &memory,
            GuestAddress(0x0),
            GuestAddress(0x100),
            vec![
                (Readable, 16),
                (Readable, 16),
                (Readable, 96),
                (Writable, 64),
                (Writable, 1),
                (Writable, 3),
            ],
            0,
        )
        .expect("create_descriptor_chain failed");

        let Reader {
            mem: _,
            mut regions,
        } = chain.reader;

        let drain = regions
            .get_remaining_regions_with_count(usize::MAX)
            .fold(0usize, |total, region| total + region.len);
        assert_eq!(drain, 128);

        let exact = regions
            .get_remaining_regions_with_count(32)
            .fold(0usize, |total, region| total + region.len);
        assert!(exact > 0);
        assert!(exact <= 32);

        let split = regions
            .get_remaining_regions_with_count(24)
            .fold(0usize, |total, region| total + region.len);
        assert!(split > 0);
        assert!(split <= 24);

        regions.consume(64);

        let first = regions
            .get_remaining_regions_with_count(8)
            .fold(0usize, |total, region| total + region.len);
        assert!(first > 0);
        assert!(first <= 8);
    }

    /// Same contract as `get_remaining_region_with_count`, but for the `VolatileSlice` view.
    #[test]
    fn get_remaining_with_count() {
        use DescriptorType::*;

        let memory_start_addr = GuestAddress(0x0);
        let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();

        let chain = create_descriptor_chain(
            &memory,
            GuestAddress(0x0),
            GuestAddress(0x100),
            vec![
                (Readable, 16),
                (Readable, 16),
                (Readable, 96),
                (Writable, 64),
                (Writable, 1),
                (Writable, 3),
            ],
            0,
        )
        .expect("create_descriptor_chain failed");

        let Reader {
            mem: _,
            mut regions,
        } = chain.reader;

        let drain = regions
            .get_remaining_with_count(&memory, usize::MAX)
            .iter()
            .fold(0usize, |total, iov| total + iov.size());
        assert_eq!(drain, 128);

        let exact = regions
            .get_remaining_with_count(&memory, 32)
            .iter()
            .fold(0usize, |total, iov| total + iov.size());
        assert!(exact > 0);
        assert!(exact <= 32);

        let split = regions
            .get_remaining_with_count(&memory, 24)
            .iter()
            .fold(0usize, |total, iov| total + iov.size());
        assert!(split > 0);
        assert!(split <= 24);

        regions.consume(64);

        let first = regions
            .get_remaining_with_count(&memory, 8)
            .iter()
            .fold(0usize, |total, iov| total + iov.size());
        assert!(first > 0);
        assert!(first <= 8);
    }

    /// `peek_obj` returns the next object without consuming it, including objects that span
    /// two descriptors, and fails when the rest of the chain is too short.
    #[test]
    fn reader_peek_obj() {
        use DescriptorType::*;

        let memory_start_addr = GuestAddress(0x0);
        let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();

        // Write test data to memory.
        memory
            .write_obj_at_addr(Le16::from(0xBEEF), GuestAddress(0x100))
            .unwrap();
        memory
            .write_obj_at_addr(Le16::from(0xDEAD), GuestAddress(0x200))
            .unwrap();

        // With `0x100 - 2` as the last argument, the second 2-byte buffer presumably starts at
        // 0x200, right where 0xDEAD was written.
        let mut chain_reader = create_descriptor_chain(
            &memory,
            GuestAddress(0x0),
            GuestAddress(0x100),
            vec![(Readable, 2), (Readable, 2)],
            0x100 - 2,
        )
        .expect("create_descriptor_chain failed");
        let reader = &mut chain_reader.reader;

        // peek_obj() at the beginning of the chain should return the first object.
        let peek1 = reader.peek_obj::<Le16>().unwrap();
        assert_eq!(peek1, Le16::from(0xBEEF));

        // peek_obj() again should return the same object, since it was not consumed.
        let peek2 = reader.peek_obj::<Le16>().unwrap();
        assert_eq!(peek2, Le16::from(0xBEEF));

        // peek_obj() of an object spanning two descriptors should copy from both.
        let peek3 = reader.peek_obj::<Le32>().unwrap();
        assert_eq!(peek3, Le32::from(0xDEADBEEF));

        // read_obj() should return the first object.
        let read1 = reader.read_obj::<Le16>().unwrap();
        assert_eq!(read1, Le16::from(0xBEEF));

        // peek_obj() of a value that is larger than the rest of the chain should fail.
        reader
            .peek_obj::<Le32>()
            .expect_err("peek_obj past end of chain");

        // read_obj() again should return the second object.
        let read2 = reader.read_obj::<Le16>().unwrap();
        assert_eq!(read2, Le16::from(0xDEAD));

        // peek_obj() should fail at the end of the chain.
        reader
            .peek_obj::<Le16>()
            .expect_err("peek_obj past end of chain");
    }

    /// Async counterpart of `reader_failing_io`, driven on a `cros_async` executor.
    #[test]
    fn region_reader_failing_io() {
        let ex = Executor::new().unwrap();
        ex.run_until(region_reader_failing_io_async(&ex)).unwrap();
    }

    /// Reads the chain into a read-only `SingleFileDisk`; the write must fail without consuming
    /// any bytes from the chain.
    async fn region_reader_failing_io_async(ex: &Executor) {
        use DescriptorType::*;

        let memory_start_addr = GuestAddress(0x0);
        let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();

        let mut chain = create_descriptor_chain(
            &memory,
            GuestAddress(0x0),
            GuestAddress(0x100),
            vec![(Readable, 256), (Readable, 256)],
            0,
        )
        .expect("create_descriptor_chain failed");

        let reader = &mut chain.reader;

        // Open a file in read-only mode so that writes to it trigger an I/O error.
        let named_temp_file = NamedTempFile::new().expect("failed to create temp file");
        let ro_file =
            File::open(named_temp_file.path()).expect("failed to open temp file read only");
        let async_ro_file = disk::SingleFileDisk::new(ro_file, ex).expect("Failed to create SFD");

        reader
            .read_exact_to_at_fut(&async_ro_file, 512, 0)
            .await
            .expect_err("successfully read more bytes than SingleFileDisk size");

        // The write above should have failed entirely, so no bytes were consumed from the chain.
        assert_eq!(reader.available_bytes(), 512);
        assert_eq!(reader.bytes_read(), 0);
    }

    /// Async counterpart of `writer_failing_io`, driven on a `cros_async` executor.
    #[test]
    fn region_writer_failing_io() {
        let ex = Executor::new().unwrap();
        ex.run_until(region_writer_failing_io_async(&ex)).unwrap();
    }

    /// Fills the chain from a 384-byte `SingleFileDisk`; requesting 512 bytes must fail, while
    /// the 384 bytes already transferred stay accounted for.
    async fn region_writer_failing_io_async(ex: &Executor) {
        use DescriptorType::*;

        let memory_start_addr = GuestAddress(0x0);
        let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();

        let mut chain = create_descriptor_chain(
            &memory,
            GuestAddress(0x0),
            GuestAddress(0x100),
            vec![(Writable, 256), (Writable, 256)],
            0,
        )
        .expect("create_descriptor_chain failed");

        let writer = &mut chain.writer;

        let file = tempfile().expect("failed to create temp file");

        file.set_len(384).unwrap();
        let async_file = disk::SingleFileDisk::new(file, ex).expect("Failed to create SFD");

        writer
            .write_all_from_at_fut(&async_file, 512, 0)
            .await
            .expect_err("successfully wrote more bytes than in SingleFileDisk");

        assert_eq!(writer.available_bytes(), 128);
        assert_eq!(writer.bytes_written(), 384);
    }
}