// Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0 OR BSD-3-Clause

//! Utilities used by unit tests and benchmarks for mocking the driver side
//! of the virtio protocol.
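//!
//! # Example
//!
//! A minimal sketch of how the mock is typically driven from a test. It
//! assumes the `backend-mmap` feature of `vm-memory` for `GuestMemoryMmap`;
//! the memory size and queue size are arbitrary:
//!
//! ```ignore
//! use virtio_queue::mock::MockSplitQueue;
//! use virtio_queue::{Queue, QueueOwnedT};
//! use vm_memory::{GuestAddress, GuestMemoryMmap};
//!
//! let mem = GuestMemoryMmap::<()>::from_ranges(&[(GuestAddress(0), 0x1_0000)]).unwrap();
//! // Play the driver: lay out a 16-entry queue and publish one 4-descriptor chain.
//! let mut vq = MockSplitQueue::new(&mem, 16);
//! vq.add_chain(4).unwrap();
//! // Play the device: build the corresponding `Queue` and consume the chain.
//! let mut queue: Queue = vq.create_queue().unwrap();
//! let chain = queue.iter(&mem).unwrap().next().unwrap();
//! assert_eq!(chain.head_index(), 0);
//! ```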

use std::fmt::{self, Debug, Display};
use std::marker::PhantomData;
use std::mem::size_of;

use virtio_bindings::bindings::virtio_ring::{VRING_DESC_F_INDIRECT, VRING_DESC_F_NEXT};
use vm_memory::{
    Address, ByteValued, Bytes, GuestAddress, GuestMemory, GuestMemoryError, GuestUsize,
};

use crate::defs::{VIRTQ_AVAIL_ELEMENT_SIZE, VIRTQ_AVAIL_RING_HEADER_SIZE};
use crate::{Descriptor, DescriptorChain, Error, Queue, QueueOwnedT, QueueT, VirtqUsedElem};

/// Mock related errors.
#[derive(Debug)]
pub enum MockError {
    /// Cannot create the Queue object due to invalid parameters.
    InvalidQueueParams(Error),
    /// Invalid Ref index
    InvalidIndex,
    /// Invalid next avail
    InvalidNextAvail,
    /// Guest memory errors
    GuestMem(GuestMemoryError),
}

impl Display for MockError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        use self::MockError::*;

        match self {
            InvalidQueueParams(_) => write!(f, "cannot create queue due to invalid parameter"),
            InvalidIndex => write!(
                f,
                "invalid index for pointing to an address in a region when defining a Ref object"
            ),
            InvalidNextAvail => write!(
                f,
                "invalid next available descriptor chain head in the queue"
            ),
            GuestMem(e) => write!(f, "guest memory error: {}", e),
        }
    }
}

impl std::error::Error for MockError {}

/// Wrapper struct used for accessing a particular address of a GuestMemory area.
pub struct Ref<'a, M, T> {
    mem: &'a M,
    addr: GuestAddress,
    phantom: PhantomData<*const T>,
}

impl<'a, M: GuestMemory, T: ByteValued> Ref<'a, M, T> {
    fn new(mem: &'a M, addr: GuestAddress) -> Self {
        Ref {
            mem,
            addr,
            phantom: PhantomData,
        }
    }

    /// Read an object of type T from the underlying memory found at self.addr.
    pub fn load(&self) -> T {
        self.mem.read_obj(self.addr).unwrap()
    }

    /// Write an object of type T to the underlying memory at self.addr.
    pub fn store(&self, val: T) {
        self.mem.write_obj(val, self.addr).unwrap()
    }
}

/// Wrapper struct used for accessing a subregion of a GuestMemory area.
pub struct ArrayRef<'a, M, T> {
    mem: &'a M,
    addr: GuestAddress,
    len: usize,
    phantom: PhantomData<*const T>,
}

impl<'a, M: GuestMemory, T: ByteValued> ArrayRef<'a, M, T> {
    fn new(mem: &'a M, addr: GuestAddress, len: usize) -> Self {
        ArrayRef {
            mem,
            addr,
            len,
            phantom: PhantomData,
        }
    }

    /// Return a `Ref` object pointing to an address defined by a particular
    /// index offset in the region.
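    ///
    /// # Example
    ///
    /// A minimal sketch of how a test might poke one ring entry through this
    /// accessor (hypothetical setup; `vq` is a `MockSplitQueue` built elsewhere):
    ///
    /// ```ignore
    /// // Make descriptor 3 the head of the first available chain.
    /// vq.avail().ring().ref_at(0).unwrap().store(3u16);
    /// ```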
    pub fn ref_at(&self, index: usize) -> Result<Ref<'a, M, T>, MockError> {
        if index >= self.len {
            return Err(MockError::InvalidIndex);
        }

        let addr = self
            .addr
            .checked_add((index * size_of::<T>()) as u64)
            .unwrap();

        Ok(Ref::new(self.mem, addr))
    }
}

/// Represents a virtio queue ring. The only difference between the used and available rings
/// is the ring element type.
pub struct SplitQueueRing<'a, M, T: ByteValued> {
    flags: Ref<'a, M, u16>,
    // The value stored here should more precisely be a `Wrapping<u16>`, but that would require a
    // `ByteValued` impl for this type, which is not provided in vm-memory. Implementing the trait
    // here would require defining a wrapper for `Wrapping<u16>` and that would be too much for a
    // mock framework that is only used in tests.
    idx: Ref<'a, M, u16>,
    ring: ArrayRef<'a, M, T>,
    // `used_event` for `AvailRing`, `avail_event` for `UsedRing`.
    event: Ref<'a, M, u16>,
}

impl<'a, M: GuestMemory, T: ByteValued> SplitQueueRing<'a, M, T> {
    /// Create a new `SplitQueueRing` instance.
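    ///
    /// The ring is laid out as in the virtio split-ring spec: `flags` at
    /// `base`, `idx` at `base + 2`, the ring entries at `base + 4`, and the
    /// final `used_event`/`avail_event` word right after the entries.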
    pub fn new(mem: &'a M, base: GuestAddress, len: u16) -> Self {
        // The event word lives right past the `len` ring entries of type `T`.
        let event_addr = base
            .checked_add(4)
            .and_then(|a| a.checked_add((size_of::<T>() * len as usize) as u64))
            .unwrap();

        let split_queue_ring = SplitQueueRing {
            flags: Ref::new(mem, base),
            idx: Ref::new(mem, base.checked_add(2).unwrap()),
            ring: ArrayRef::new(mem, base.checked_add(4).unwrap(), len as usize),
            event: Ref::new(mem, event_addr),
        };

        split_queue_ring.flags.store(0);
        split_queue_ring.idx.store(0);
        split_queue_ring.event.store(0);

        split_queue_ring
    }

    /// Return the starting address of the `SplitQueueRing`.
    pub fn start(&self) -> GuestAddress {
        self.ring.addr
    }

    /// Return the end address of the `SplitQueueRing`.
    pub fn end(&self) -> GuestAddress {
        // The ring spans `len` entries of type `T`, followed by one final
        // `u16` event word.
        self.start()
            .checked_add((self.ring.len * size_of::<T>()) as GuestUsize)
            .and_then(|a| a.checked_add(size_of::<u16>() as GuestUsize))
            .unwrap()
    }

    /// Return a reference to the idx field.
    pub fn idx(&self) -> &Ref<'a, M, u16> {
        &self.idx
    }

    /// Return a reference to the ring field.
    pub fn ring(&self) -> &ArrayRef<'a, M, T> {
        &self.ring
    }
}

/// The available ring is used by the driver to offer buffers to the device.
pub type AvailRing<'a, M> = SplitQueueRing<'a, M, u16>;
/// The used ring is where the device returns buffers once it is done with them.
pub type UsedRing<'a, M> = SplitQueueRing<'a, M, VirtqUsedElem>;

/// Refers to the buffers the driver is using for the device.
pub struct DescriptorTable<'a, M> {
    table: ArrayRef<'a, M, Descriptor>,
    len: u16,
    free_descriptors: Vec<u16>,
}

impl<'a, M: GuestMemory> DescriptorTable<'a, M> {
    /// Create a new `DescriptorTable` instance
    pub fn new(mem: &'a M, addr: GuestAddress, len: u16) -> Self {
        let table = ArrayRef::new(mem, addr, len as usize);
        let free_descriptors = (0..len).rev().collect();

        DescriptorTable {
            table,
            len,
            free_descriptors,
        }
    }

    /// Read one descriptor from the specified index.
    pub fn load(&self, index: u16) -> Result<Descriptor, MockError> {
        self.table
            .ref_at(index as usize)
            .map(|load_ref| load_ref.load())
    }

    /// Write one descriptor at the specified index.
    pub fn store(&self, index: u16, value: Descriptor) -> Result<(), MockError> {
        self.table
            .ref_at(index as usize)
            .map(|store_ref| store_ref.store(value))
    }

    /// Return the total size of the DescriptorTable in bytes.
    pub fn total_size(&self) -> u64 {
        (self.len as usize * size_of::<Descriptor>()) as u64
    }

    /// Create a chain of descriptors.
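    ///
    /// # Example
    ///
    /// A sketch under the assumption that `mem` is some `GuestMemory`
    /// implementation (e.g. a `GuestMemoryMmap` from `vm-memory`):
    ///
    /// ```ignore
    /// let mut dt = DescriptorTable::new(&mem, GuestAddress(0), 8);
    /// // Link three free descriptors into a chain and get the head index.
    /// let head = dt.build_chain(3).unwrap();
    /// let first = dt.load(head).unwrap();
    /// assert_eq!(first.flags() & VRING_DESC_F_NEXT as u16, VRING_DESC_F_NEXT as u16);
    /// ```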
    pub fn build_chain(&mut self, len: u16) -> Result<u16, MockError> {
        // The free list is stored in reverse order, so draining from the back
        // hands out the lowest-numbered free descriptors, in ascending order.
        // Draining (rather than just copying) marks them as consumed, so
        // successive chains use distinct descriptors.
        let split_index = self.free_descriptors.len() - usize::from(len);
        let indices = self
            .free_descriptors
            .drain(split_index..)
            .rev()
            .collect::<Vec<_>>();

        assert_eq!(indices.len(), len as usize);

        for (pos, index_value) in indices.iter().copied().enumerate() {
            // Addresses and lens are constant for now.
            let mut desc = Descriptor::new(0x1000, 0x1000, 0, 0);

            // It's not the last descriptor in the chain.
            if pos < indices.len() - 1 {
                desc.set_flags(VRING_DESC_F_NEXT as u16);
                desc.set_next(indices[pos + 1]);
            } else {
                desc.set_flags(0);
            }
            self.store(index_value, desc)?;
        }

        Ok(indices[0])
    }
}

trait GuestAddressExt {
    fn align_up(&self, x: GuestUsize) -> GuestAddress;
}

impl GuestAddressExt for GuestAddress {
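    // Round `self` up to the next multiple of `x`, which must be a power of
    // two (e.g. 0x1001 aligned up to 4 gives 0x1004).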
    fn align_up(&self, x: GuestUsize) -> GuestAddress {
        Self((self.0 + (x - 1)) & !(x - 1))
    }
}

/// A mock version of the virtio queue implemented from the perspective of the driver.
pub struct MockSplitQueue<'a, M> {
    mem: &'a M,
    len: u16,
    desc_table_addr: GuestAddress,
    desc_table: DescriptorTable<'a, M>,
    avail_addr: GuestAddress,
    avail: AvailRing<'a, M>,
    used_addr: GuestAddress,
    used: UsedRing<'a, M>,
    indirect_addr: GuestAddress,
}

impl<'a, M: GuestMemory> MockSplitQueue<'a, M> {
    /// Create a new `MockSplitQueue` instance with 0 as the default guest
    /// physical starting address.
    pub fn new(mem: &'a M, len: u16) -> Self {
        Self::create(mem, GuestAddress(0), len)
    }

    /// Create a new `MockSplitQueue` instance.
    pub fn create(mem: &'a M, start: GuestAddress, len: u16) -> Self {
        const AVAIL_ALIGN: GuestUsize = 2;
        const USED_ALIGN: GuestUsize = 4;

        let desc_table_addr = start;
        let desc_table = DescriptorTable::new(mem, desc_table_addr, len);

        // The descriptor table takes 16 bytes per entry; the available ring
        // starts right after it, subject to the required alignment.
        let avail_addr = start
            .checked_add(16 * len as GuestUsize)
            .unwrap()
            .align_up(AVAIL_ALIGN);
        let avail = AvailRing::new(mem, avail_addr, len);

        let used_addr = avail.end().align_up(USED_ALIGN);
        let used = UsedRing::new(mem, used_addr, len);

        // Indirect descriptor tables are carved out of a separate region.
        let indirect_addr = GuestAddress(0x3000_0000);

        MockSplitQueue {
            mem,
            len,
            desc_table_addr,
            desc_table,
            avail_addr,
            avail,
            used_addr,
            used,
            indirect_addr,
        }
    }

    /// Return the starting address of the queue.
    pub fn start(&self) -> GuestAddress {
        self.desc_table_addr
    }

    /// Return the end address of the queue.
    pub fn end(&self) -> GuestAddress {
        self.used.end()
    }

    /// Descriptor table accessor.
    pub fn desc_table(&self) -> &DescriptorTable<'a, M> {
        &self.desc_table
    }

    /// Available ring accessor.
    pub fn avail(&self) -> &AvailRing<M> {
        &self.avail
    }

    /// Used ring accessor.
    pub fn used(&self) -> &UsedRing<M> {
        &self.used
    }

    /// Return the starting address of the descriptor table.
    pub fn desc_table_addr(&self) -> GuestAddress {
        self.desc_table_addr
    }

    /// Return the starting address of the available ring.
    pub fn avail_addr(&self) -> GuestAddress {
        self.avail_addr
    }

    /// Return the starting address of the used ring.
    pub fn used_addr(&self) -> GuestAddress {
        self.used_addr
    }

    fn update_avail_idx(&mut self, value: u16) -> Result<(), MockError> {
        // Publish `value` (a chain head index) in the next free avail ring
        // slot, then bump the driver's avail index.
        let avail_idx = self.avail.idx.load();
        self.avail.ring.ref_at(avail_idx as usize)?.store(value);
        self.avail.idx.store(avail_idx.wrapping_add(1));
        Ok(())
    }

    fn alloc_indirect_chain(&mut self, len: u16) -> Result<GuestAddress, MockError> {
        // To simplify things for now, we round up the table len to a multiple of 16. When this is
        // no longer the case, we should make sure the starting address of the descriptor table
        // we're creating below is properly aligned.

        let table_len = if len % 16 == 0 {
            len
        } else {
            16 * (len / 16 + 1)
        };

        let mut table = DescriptorTable::new(self.mem, self.indirect_addr, table_len);
        let head_descriptor_index = table.build_chain(len)?;
        // When building indirect descriptor tables, the descriptor at index 0 is supposed to be
        // first in the resulting chain. Just making sure our logic actually makes that happen.
        assert_eq!(head_descriptor_index, 0);

        let table_addr = self.indirect_addr;
        self.indirect_addr = self.indirect_addr.checked_add(table.total_size()).unwrap();
        Ok(table_addr)
    }

    /// Add a descriptor chain to the table.
    pub fn add_chain(&mut self, len: u16) -> Result<(), MockError> {
        self.desc_table
            .build_chain(len)
            .and_then(|head_idx| self.update_avail_idx(head_idx))
    }

    /// Add an indirect descriptor chain to the table.
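    ///
    /// # Example
    ///
    /// A sketch (hypothetical `vq` built with `MockSplitQueue::new`):
    ///
    /// ```ignore
    /// // Publish one chain whose single direct descriptor points at an
    /// // indirect table holding 16 descriptors.
    /// vq.add_indirect_chain(16).unwrap();
    /// ```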
    pub fn add_indirect_chain(&mut self, len: u16) -> Result<(), MockError> {
        let head_idx = self.desc_table.build_chain(1)?;

        // We just allocate the indirect table and forget about it for now.
        let indirect_addr = self.alloc_indirect_chain(len)?;

        let mut desc = self.desc_table.load(head_idx)?;
        desc.set_flags(VRING_DESC_F_INDIRECT as u16);
        desc.set_addr(indirect_addr.raw_value());
        desc.set_len(u32::from(len) * size_of::<Descriptor>() as u32);

        self.desc_table.store(head_idx, desc)?;
        self.update_avail_idx(head_idx)
    }

    /// Creates a new `Queue`, using the underlying memory regions represented
    /// by the `MockSplitQueue`.
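    ///
    /// # Example
    ///
    /// A sketch (hypothetical `vq`; `Queue` is this crate's own `QueueT`
    /// implementation):
    ///
    /// ```ignore
    /// let queue: Queue = vq.create_queue().unwrap();
    /// assert!(queue.ready());
    /// ```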
    pub fn create_queue<Q: QueueT>(&self) -> Result<Q, Error> {
        let mut q = Q::new(self.len)?;
        q.set_size(self.len);
        q.set_ready(true);
        // We cannot directly set the u64 address; we need to compose it from low & high.
        q.set_desc_table_address(
            Some(self.desc_table_addr.0 as u32),
            Some((self.desc_table_addr.0 >> 32) as u32),
        );
        q.set_avail_ring_address(
            Some(self.avail_addr.0 as u32),
            Some((self.avail_addr.0 >> 32) as u32),
        );
        q.set_used_ring_address(
            Some(self.used_addr.0 as u32),
            Some((self.used_addr.0 >> 32) as u32),
        );
        Ok(q)
    }

    /// Writes multiple descriptor chains to the memory object of the queue, at the beginning of
    /// the descriptor table, and returns the first `DescriptorChain` available.
    pub fn build_multiple_desc_chains(
        &self,
        descs: &[Descriptor],
    ) -> Result<DescriptorChain<&M>, MockError> {
        self.add_desc_chains(descs, 0)?;
        self.create_queue::<Queue>()
            .map_err(MockError::InvalidQueueParams)?
            .iter(self.mem)
            .map_err(MockError::InvalidQueueParams)?
            .next()
            .ok_or(MockError::InvalidNextAvail)
    }

    /// Writes a single descriptor chain to the memory object of the queue, at the beginning of the
    /// descriptor table, and returns the associated `DescriptorChain` object.
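    ///
    /// # Example
    ///
    /// A sketch (hypothetical `vq`; the `addr` values deliberately sit above
    /// the queue's own layout):
    ///
    /// ```ignore
    /// let descs = [
    ///     Descriptor::new(0x10_0000, 0x200, 0, 0),
    ///     Descriptor::new(0x20_0000, 0x200, 0, 0),
    /// ];
    /// // `build_desc_chain` links the two descriptors itself, so the input
    /// // `flags`/`next` values need not form a chain.
    /// let chain = vq.build_desc_chain(&descs).unwrap();
    /// assert_eq!(chain.head_index(), 0);
    /// ```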
    // This method ensures the next flags and values are set properly for the desired chain, but
    // keeps the other characteristics of the input descriptors (`addr`, `len`, other flags).
    // TODO: make this function work with a generic queue. For now that's not possible because
    // we cannot create the descriptor chain from an iterator as iterator is not implemented for
    // a generic T, just for `Queue`.
    pub fn build_desc_chain(&self, descs: &[Descriptor]) -> Result<DescriptorChain<&M>, MockError> {
        let mut modified_descs: Vec<Descriptor> = Vec::with_capacity(descs.len());
        for (idx, desc) in descs.iter().enumerate() {
            let (flags, next) = if idx == descs.len() - 1 {
                // Clear the NEXT flag if it was set. The value of the next field of the
                // Descriptor doesn't matter at this point.
                (desc.flags() & !VRING_DESC_F_NEXT as u16, 0)
            } else {
                // Ensure that the next flag is set and that we are referring to the following
                // descriptor. This ignores any value actually present in `desc.next`.
                (desc.flags() | VRING_DESC_F_NEXT as u16, idx as u16 + 1)
            };
            modified_descs.push(Descriptor::new(desc.addr().0, desc.len(), flags, next));
        }
        self.build_multiple_desc_chains(&modified_descs[..])
    }

    /// Adds descriptor chains to the memory object of the queue.
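    ///
    /// # Example
    ///
    /// A sketch of adding two single-descriptor chains at the start of the
    /// table (hypothetical `vq`):
    ///
    /// ```ignore
    /// let descs = [
    ///     Descriptor::new(0x10_0000, 0x100, 0, 0),
    ///     Descriptor::new(0x20_0000, 0x100, 0, 0),
    /// ];
    /// // Neither descriptor has VRING_DESC_F_NEXT set, so this publishes two
    /// // separate chains and bumps `avail_idx` by two.
    /// vq.add_desc_chains(&descs, 0).unwrap();
    /// ```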
    // `descs` represents a slice of `Descriptor` objects which are used to populate the chains, and
    // `offset` is the index in the descriptor table where the chains should be added.
    // The descriptor chain related information is written in memory starting with address 0.
    // The `addr` fields of the input descriptors should start at a sufficiently
    // greater location (e.g. 1MiB, or `0x10_0000`).
    pub fn add_desc_chains(&self, descs: &[Descriptor], offset: u16) -> Result<(), MockError> {
        let mut new_entries = 0;
        let avail_idx: u16 = self
            .mem
            .read_obj::<u16>(self.avail_addr().unchecked_add(2))
            .map(u16::from_le)
            .map_err(MockError::GuestMem)?;

        for (idx, desc) in descs.iter().enumerate() {
            let i = idx as u16 + offset;
            self.desc_table().store(i, *desc)?;

            // A descriptor starts a new chain if it is the first one in `descs`
            // or if the previous descriptor does not point to it via NEXT.
            if idx == 0 || descs[idx - 1].flags() & VRING_DESC_F_NEXT as u16 == 0 {
                // Update the available ring position.
                self.mem
                    .write_obj(
                        u16::to_le(i),
                        self.avail_addr().unchecked_add(
                            VIRTQ_AVAIL_RING_HEADER_SIZE
                                + (avail_idx + new_entries) as u64 * VIRTQ_AVAIL_ELEMENT_SIZE,
                        ),
                    )
                    .map_err(MockError::GuestMem)?;
                new_entries += 1;
            }
        }

        // Increment `avail_idx` by the number of new chain heads.
        self.mem
            .write_obj(
                u16::to_le(avail_idx + new_entries),
                self.avail_addr().unchecked_add(2),
            )
            .map_err(MockError::GuestMem)?;

        Ok(())
    }
}