1 // Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 //
3 // Portions Copyright 2017 The Chromium OS Authors. All rights reserved.
4 //
5 // SPDX-License-Identifier: BSD-3-Clause
6 
7 //! Trait and wrapper for working with C defined FAM structures.
8 //!
9 //! In C 99 an array of unknown size may appear within a struct definition as the last member
10 //! (as long as there is at least one other named member).
11 //! This is known as a flexible array member (FAM).
12 //! Pre C99, the same behavior could be achieved using zero length arrays.
13 //!
14 //! Flexible Array Members are the go-to choice for working with large amounts of data
15 //! prefixed by header values.
16 //!
17 //! For example the KVM API has many structures of this kind.
18 
19 #[cfg(feature = "with-serde")]
20 use serde::de::{self, Deserialize, Deserializer, SeqAccess, Visitor};
21 #[cfg(feature = "with-serde")]
22 use serde::{ser::SerializeTuple, Serialize, Serializer};
23 use std::fmt;
24 #[cfg(feature = "with-serde")]
25 use std::marker::PhantomData;
26 use std::mem::{self, size_of};
27 
/// Errors that can be returned by [`FamStructWrapper`](struct.FamStructWrapper.html) operations.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Error {
    /// The requested length is larger than the maximum allowed size.
    SizeLimitExceeded,
}

impl std::error::Error for Error {}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        // Map each variant to its human readable description.
        let description = match self {
            Self::SizeLimitExceeded => "The max size has been exceeded",
        };
        f.write_str(description)
    }
}
44 
/// Trait for accessing properties of C defined FAM structures.
///
/// # Safety
///
/// This is unsafe due to the number of constraints that aren't checked:
/// * the implementer should be a POD
/// * the implementer should contain a flexible array member of elements of type `Entry`
/// * `Entry` should be a POD
/// * the implementer should ensure that the FAM length as returned by [`FamStruct::len()`]
///   always correctly describes the length of the flexible array member.
///
/// Violating these requirements can cause undefined behavior. For example, the slices returned
/// by [`FamStruct::as_slice()`] / [`FamStruct::as_mut_slice()`] could extend past the memory
/// that actually backs the entries.
///
/// # Example
///
/// ```
/// use vmm_sys_util::fam::*;
///
/// #[repr(C)]
/// #[derive(Default)]
/// pub struct __IncompleteArrayField<T>(::std::marker::PhantomData<T>, [T; 0]);
/// impl<T> __IncompleteArrayField<T> {
///     #[inline]
///     pub fn new() -> Self {
///         __IncompleteArrayField(::std::marker::PhantomData, [])
///     }
///     #[inline]
///     pub unsafe fn as_ptr(&self) -> *const T {
///         self as *const __IncompleteArrayField<T> as *const T
///     }
///     #[inline]
///     pub unsafe fn as_mut_ptr(&mut self) -> *mut T {
///         self as *mut __IncompleteArrayField<T> as *mut T
///     }
///     #[inline]
///     pub unsafe fn as_slice(&self, len: usize) -> &[T] {
///         ::std::slice::from_raw_parts(self.as_ptr(), len)
///     }
///     #[inline]
///     pub unsafe fn as_mut_slice(&mut self, len: usize) -> &mut [T] {
///         ::std::slice::from_raw_parts_mut(self.as_mut_ptr(), len)
///     }
/// }
///
/// #[repr(C)]
/// #[derive(Default)]
/// struct MockFamStruct {
///     pub len: u32,
///     pub padding: u32,
///     pub entries: __IncompleteArrayField<u32>,
/// }
///
/// unsafe impl FamStruct for MockFamStruct {
///     type Entry = u32;
///
///     fn len(&self) -> usize {
///         self.len as usize
///     }
///
///     unsafe fn set_len(&mut self, len: usize) {
///         self.len = len as u32
///     }
///
///     fn max_len() -> usize {
///         100
///     }
///
///     fn as_slice(&self) -> &[u32] {
///         let len = self.len();
///         unsafe { self.entries.as_slice(len) }
///     }
///
///     fn as_mut_slice(&mut self) -> &mut [u32] {
///         let len = self.len();
///         unsafe { self.entries.as_mut_slice(len) }
///     }
/// }
///
/// type MockFamStructWrapper = FamStructWrapper<MockFamStruct>;
/// ```
// The trait mirrors C structs that only expose a length field, so it does not ask
// implementers for an `is_empty` counterpart to `len`.
#[allow(clippy::len_without_is_empty)]
pub unsafe trait FamStruct {
    /// The type of the FAM entries
    type Entry: PartialEq + Copy;

    /// Get the FAM length
    ///
    /// Structures of this kind contain a member that holds the FAM length.
    /// This method returns the value of that member.
    fn len(&self) -> usize;

    /// Set the FAM length
    ///
    /// Structures of this kind contain a member that holds the FAM length.
    /// This method sets the value of that member.
    ///
    /// # Safety
    ///
    /// The caller needs to ensure that `len` here reflects the correct number of entries of the
    /// flexible array part of the struct.
    unsafe fn set_len(&mut self, len: usize);

    /// Get max allowed FAM length
    ///
    /// This depends on each structure.
    /// For example a structure representing the cpuid can contain at most 80 entries.
    fn max_len() -> usize;

    /// Get the FAM entries as slice.
    ///
    /// The slice is expected to contain exactly [`FamStruct::len()`] entries.
    fn as_slice(&self) -> &[Self::Entry];

    /// Get the FAM entries as mut slice.
    ///
    /// The slice is expected to contain exactly [`FamStruct::len()`] entries.
    fn as_mut_slice(&mut self) -> &mut [Self::Entry];
}
159 
160 /// A wrapper for [`FamStruct`](trait.FamStruct.html).
161 ///
162 /// It helps in treating a [`FamStruct`](trait.FamStruct.html) similarly to an actual `Vec`.
163 #[derive(Debug)]
164 pub struct FamStructWrapper<T: Default + FamStruct> {
165     // This variable holds the FamStruct structure. We use a `Vec<T>` to make the allocation
166     // large enough while still being aligned for `T`. Only the first element of `Vec<T>`
167     // will actually be used as a `T`. The remaining memory in the `Vec<T>` is for `entries`,
168     // which must be contiguous. Since the entries are of type `FamStruct::Entry` we must
169     // be careful to convert the desired capacity of the `FamStructWrapper`
170     // from `FamStruct::Entry` to `T` when reserving or releasing memory.
171     mem_allocator: Vec<T>,
172 }
173 
174 impl<T: Default + FamStruct> FamStructWrapper<T> {
175     /// Convert FAM len to `mem_allocator` len.
176     ///
177     /// Get the capacity required by mem_allocator in order to hold
178     /// the provided number of [`FamStruct::Entry`](trait.FamStruct.html#associatedtype.Entry).
179     /// Returns `None` if the required length would overflow usize.
mem_allocator_len(fam_len: usize) -> Option<usize>180     fn mem_allocator_len(fam_len: usize) -> Option<usize> {
181         let wrapper_size_in_bytes =
182             size_of::<T>().checked_add(fam_len.checked_mul(size_of::<T::Entry>())?)?;
183 
184         wrapper_size_in_bytes
185             .checked_add(size_of::<T>().checked_sub(1)?)?
186             .checked_div(size_of::<T>())
187     }
188 
189     /// Convert `mem_allocator` len to FAM len.
190     ///
191     /// Get the number of elements of type
192     /// [`FamStruct::Entry`](trait.FamStruct.html#associatedtype.Entry)
193     /// that fit in a mem_allocator of provided len.
fam_len(mem_allocator_len: usize) -> usize194     fn fam_len(mem_allocator_len: usize) -> usize {
195         if mem_allocator_len == 0 {
196             return 0;
197         }
198 
199         let array_size_in_bytes = (mem_allocator_len - 1) * size_of::<T>();
200         array_size_in_bytes / size_of::<T::Entry>()
201     }
202 
203     /// Create a new FamStructWrapper with `num_elements` elements.
204     ///
205     /// The elements will be zero-initialized. The type of the elements will be
206     /// [`FamStruct::Entry`](trait.FamStruct.html#associatedtype.Entry).
207     ///
208     /// # Arguments
209     ///
210     /// * `num_elements` - The number of elements in the FamStructWrapper.
211     ///
212     /// # Errors
213     ///
214     /// When `num_elements` is greater than the max possible len, it returns
215     /// `Error::SizeLimitExceeded`.
new(num_elements: usize) -> Result<FamStructWrapper<T>, Error>216     pub fn new(num_elements: usize) -> Result<FamStructWrapper<T>, Error> {
217         if num_elements > T::max_len() {
218             return Err(Error::SizeLimitExceeded);
219         }
220         let required_mem_allocator_capacity =
221             FamStructWrapper::<T>::mem_allocator_len(num_elements)
222                 .ok_or(Error::SizeLimitExceeded)?;
223 
224         let mut mem_allocator = Vec::with_capacity(required_mem_allocator_capacity);
225         mem_allocator.push(T::default());
226         for _ in 1..required_mem_allocator_capacity {
227             // SAFETY: Safe as long T follows the requirements of being POD.
228             mem_allocator.push(unsafe { mem::zeroed() })
229         }
230         // SAFETY: The flexible array part of the struct has `num_elements` capacity. We just
231         // initialized this in `mem_allocator`.
232         unsafe {
233             mem_allocator[0].set_len(num_elements);
234         }
235 
236         Ok(FamStructWrapper { mem_allocator })
237     }
238 
239     /// Create a new FamStructWrapper from a slice of elements.
240     ///
241     /// # Arguments
242     ///
243     /// * `entries` - The slice of [`FamStruct::Entry`](trait.FamStruct.html#associatedtype.Entry)
244     ///               entries.
245     ///
246     /// # Errors
247     ///
248     /// When the size of `entries` is greater than the max possible len, it returns
249     /// `Error::SizeLimitExceeded`.
from_entries(entries: &[T::Entry]) -> Result<FamStructWrapper<T>, Error>250     pub fn from_entries(entries: &[T::Entry]) -> Result<FamStructWrapper<T>, Error> {
251         let mut adapter = FamStructWrapper::<T>::new(entries.len())?;
252 
253         {
254             // SAFETY: We are not modifying the length of the FamStruct
255             let wrapper_entries = unsafe { adapter.as_mut_fam_struct().as_mut_slice() };
256             wrapper_entries.copy_from_slice(entries);
257         }
258 
259         Ok(adapter)
260     }
261 
262     /// Create a new FamStructWrapper from the raw content represented as `Vec<T>`.
263     ///
264     /// Sometimes we already have the raw content of an FAM struct represented as `Vec<T>`,
265     /// and want to use the FamStructWrapper as accessors.
266     ///
267     /// # Arguments
268     ///
269     /// * `content` - The raw content represented as `Vec[T]`.
270     ///
271     /// # Safety
272     ///
273     /// This function is unsafe because the caller needs to ensure that the raw content is
274     /// correctly layed out.
from_raw(content: Vec<T>) -> Self275     pub unsafe fn from_raw(content: Vec<T>) -> Self {
276         FamStructWrapper {
277             mem_allocator: content,
278         }
279     }
280 
281     /// Consume the FamStructWrapper and return the raw content as `Vec<T>`.
into_raw(self) -> Vec<T>282     pub fn into_raw(self) -> Vec<T> {
283         self.mem_allocator
284     }
285 
286     /// Get a reference to the actual [`FamStruct`](trait.FamStruct.html) instance.
as_fam_struct_ref(&self) -> &T287     pub fn as_fam_struct_ref(&self) -> &T {
288         &self.mem_allocator[0]
289     }
290 
291     /// Get a mut reference to the actual [`FamStruct`](trait.FamStruct.html) instance.
292     ///
293     /// # Safety
294     ///
295     /// Callers must not use the reference returned to modify the `len` filed of the underlying
296     /// `FamStruct`. See also the top-level documentation of [`FamStruct`].
as_mut_fam_struct(&mut self) -> &mut T297     pub unsafe fn as_mut_fam_struct(&mut self) -> &mut T {
298         &mut self.mem_allocator[0]
299     }
300 
301     /// Get a pointer to the [`FamStruct`](trait.FamStruct.html) instance.
302     ///
303     /// The caller must ensure that the fam_struct outlives the pointer this
304     /// function returns, or else it will end up pointing to garbage.
305     ///
306     /// Modifying the container referenced by this pointer may cause its buffer
307     /// to be reallocated, which would also make any pointers to it invalid.
as_fam_struct_ptr(&self) -> *const T308     pub fn as_fam_struct_ptr(&self) -> *const T {
309         self.as_fam_struct_ref()
310     }
311 
312     /// Get a mutable pointer to the [`FamStruct`](trait.FamStruct.html) instance.
313     ///
314     /// The caller must ensure that the fam_struct outlives the pointer this
315     /// function returns, or else it will end up pointing to garbage.
316     ///
317     /// Modifying the container referenced by this pointer may cause its buffer
318     /// to be reallocated, which would also make any pointers to it invalid.
as_mut_fam_struct_ptr(&mut self) -> *mut T319     pub fn as_mut_fam_struct_ptr(&mut self) -> *mut T {
320         // SAFETY: We do not change the length of the underlying FamStruct.
321         unsafe { self.as_mut_fam_struct() }
322     }
323 
324     /// Get the elements slice.
as_slice(&self) -> &[T::Entry]325     pub fn as_slice(&self) -> &[T::Entry] {
326         self.as_fam_struct_ref().as_slice()
327     }
328 
329     /// Get the mutable elements slice.
as_mut_slice(&mut self) -> &mut [T::Entry]330     pub fn as_mut_slice(&mut self) -> &mut [T::Entry] {
331         // SAFETY: We do not change the length of the underlying FamStruct.
332         unsafe { self.as_mut_fam_struct() }.as_mut_slice()
333     }
334 
335     /// Get the number of elements of type `FamStruct::Entry` currently in the vec.
len(&self) -> usize336     fn len(&self) -> usize {
337         self.as_fam_struct_ref().len()
338     }
339 
340     /// Get the capacity of the `FamStructWrapper`
341     ///
342     /// The capacity is measured in elements of type `FamStruct::Entry`.
capacity(&self) -> usize343     fn capacity(&self) -> usize {
344         FamStructWrapper::<T>::fam_len(self.mem_allocator.capacity())
345     }
346 
347     /// Reserve additional capacity.
348     ///
349     /// Reserve capacity for at least `additional` more
350     /// [`FamStruct::Entry`](trait.FamStruct.html#associatedtype.Entry) elements.
351     ///
352     /// If the capacity is already reserved, this method doesn't do anything.
353     /// If not this will trigger a reallocation of the underlying buffer.
reserve(&mut self, additional: usize) -> Result<(), Error>354     fn reserve(&mut self, additional: usize) -> Result<(), Error> {
355         let desired_capacity = self.len() + additional;
356         if desired_capacity <= self.capacity() {
357             return Ok(());
358         }
359 
360         let current_mem_allocator_len = self.mem_allocator.len();
361         let required_mem_allocator_len = FamStructWrapper::<T>::mem_allocator_len(desired_capacity)
362             .ok_or(Error::SizeLimitExceeded)?;
363         let additional_mem_allocator_len = required_mem_allocator_len - current_mem_allocator_len;
364 
365         self.mem_allocator.reserve(additional_mem_allocator_len);
366 
367         Ok(())
368     }
369 
370     /// Update the length of the FamStructWrapper.
371     ///
372     /// The length of `self` will be updated to the specified value.
373     /// The length of the `T` structure and of `self.mem_allocator` will be updated accordingly.
374     /// If the len is increased additional capacity will be reserved.
375     /// If the len is decreased the unnecessary memory will be deallocated.
376     ///
377     /// This method might trigger reallocations of the underlying buffer.
378     ///
379     /// # Errors
380     ///
381     /// When len is greater than the max possible len it returns Error::SizeLimitExceeded.
set_len(&mut self, len: usize) -> Result<(), Error>382     fn set_len(&mut self, len: usize) -> Result<(), Error> {
383         let additional_elements = isize::try_from(len)
384             .and_then(|len| isize::try_from(self.len()).map(|self_len| len - self_len))
385             .map_err(|_| Error::SizeLimitExceeded)?;
386 
387         // If len == self.len there's nothing to do.
388         if additional_elements == 0 {
389             return Ok(());
390         }
391 
392         // If the len needs to be increased:
393         if additional_elements > 0 {
394             // Check if the new len is valid.
395             if len > T::max_len() {
396                 return Err(Error::SizeLimitExceeded);
397             }
398             // Reserve additional capacity.
399             self.reserve(additional_elements as usize)?;
400         }
401 
402         let current_mem_allocator_len = self.mem_allocator.len();
403         let required_mem_allocator_len =
404             FamStructWrapper::<T>::mem_allocator_len(len).ok_or(Error::SizeLimitExceeded)?;
405         // Update the len of the `mem_allocator`.
406         // SAFETY: This is safe since enough capacity has been reserved.
407         unsafe {
408             self.mem_allocator.set_len(required_mem_allocator_len);
409         }
410         // Zero-initialize the additional elements if any.
411         for i in current_mem_allocator_len..required_mem_allocator_len {
412             // SAFETY: Safe as long as the trait is only implemented for POD. This is a requirement
413             // for the trait implementation.
414             self.mem_allocator[i] = unsafe { mem::zeroed() }
415         }
416         // Update the len of the underlying `FamStruct`.
417         // SAFETY: We just adjusted the memory for the underlying `mem_allocator` to hold `len`
418         // entries.
419         unsafe {
420             self.as_mut_fam_struct().set_len(len);
421         }
422 
423         // If the len needs to be decreased, deallocate unnecessary memory
424         if additional_elements < 0 {
425             self.mem_allocator.shrink_to_fit();
426         }
427 
428         Ok(())
429     }
430 
431     /// Append an element.
432     ///
433     /// # Arguments
434     ///
435     /// * `entry` - The element that will be appended to the end of the collection.
436     ///
437     /// # Errors
438     ///
439     /// When len is already equal to max possible len it returns Error::SizeLimitExceeded.
push(&mut self, entry: T::Entry) -> Result<(), Error>440     pub fn push(&mut self, entry: T::Entry) -> Result<(), Error> {
441         let new_len = self.len() + 1;
442         self.set_len(new_len)?;
443         self.as_mut_slice()[new_len - 1] = entry;
444 
445         Ok(())
446     }
447 
448     /// Retain only the elements specified by the predicate.
449     ///
450     /// # Arguments
451     ///
452     /// * `f` - The function used to evaluate whether an entry will be kept or not.
453     ///         When `f` returns `true` the entry is kept.
retain<P>(&mut self, mut f: P) where P: FnMut(&T::Entry) -> bool,454     pub fn retain<P>(&mut self, mut f: P)
455     where
456         P: FnMut(&T::Entry) -> bool,
457     {
458         let mut num_kept_entries = 0;
459         {
460             let entries = self.as_mut_slice();
461             for entry_idx in 0..entries.len() {
462                 let keep = f(&entries[entry_idx]);
463                 if keep {
464                     entries[num_kept_entries] = entries[entry_idx];
465                     num_kept_entries += 1;
466                 }
467             }
468         }
469 
470         // This is safe since this method is not increasing the len
471         self.set_len(num_kept_entries).expect("invalid length");
472     }
473 }
474 
475 impl<T: Default + FamStruct + PartialEq> PartialEq for FamStructWrapper<T> {
eq(&self, other: &FamStructWrapper<T>) -> bool476     fn eq(&self, other: &FamStructWrapper<T>) -> bool {
477         self.as_fam_struct_ref() == other.as_fam_struct_ref() && self.as_slice() == other.as_slice()
478     }
479 }
480 
481 impl<T: Default + FamStruct> Clone for FamStructWrapper<T> {
clone(&self) -> Self482     fn clone(&self) -> Self {
483         // The number of entries (self.as_slice().len()) can't be > T::max_len() since `self` is a
484         // valid `FamStructWrapper`. This makes the .unwrap() safe.
485         let required_mem_allocator_capacity =
486             FamStructWrapper::<T>::mem_allocator_len(self.as_slice().len()).unwrap();
487 
488         let mut mem_allocator = Vec::with_capacity(required_mem_allocator_capacity);
489 
490         // SAFETY: This is safe as long as the requirements for the `FamStruct` trait to be safe
491         // are met (the implementing type and the entries elements are POD, therefore `Copy`, so
492         // memory safety can't be violated by the ownership of `fam_struct`). It is also safe
493         // because we're trying to read a T from a `&T` that is pointing to a properly initialized
494         // and aligned T.
495         unsafe {
496             let fam_struct: T = std::ptr::read(self.as_fam_struct_ref());
497             mem_allocator.push(fam_struct);
498         }
499         for _ in 1..required_mem_allocator_capacity {
500             mem_allocator.push(
501                 // SAFETY: This is safe as long as T respects the FamStruct trait and is a POD.
502                 unsafe { mem::zeroed() },
503             )
504         }
505 
506         let mut adapter = FamStructWrapper { mem_allocator };
507         {
508             let wrapper_entries = adapter.as_mut_slice();
509             wrapper_entries.copy_from_slice(self.as_slice());
510         }
511         adapter
512     }
513 }
514 
515 impl<T: Default + FamStruct> From<Vec<T>> for FamStructWrapper<T> {
from(vec: Vec<T>) -> Self516     fn from(vec: Vec<T>) -> Self {
517         FamStructWrapper { mem_allocator: vec }
518     }
519 }
520 
#[cfg(feature = "with-serde")]
impl<T: Default + FamStruct + Serialize> Serialize for FamStructWrapper<T>
where
    <T as FamStruct>::Entry: serde::Serialize,
{
    /// Serializes the wrapper as a 2-tuple: `(header, entries)`.
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let header = self.as_fam_struct_ref();
        let entries = self.as_slice();

        let mut tuple = serializer.serialize_tuple(2)?;
        tuple.serialize_element(header)?;
        tuple.serialize_element(entries)?;
        tuple.end()
    }
}
536 
#[cfg(feature = "with-serde")]
impl<'de, T: Default + FamStruct + Deserialize<'de>> Deserialize<'de> for FamStructWrapper<T>
where
    <T as FamStruct>::Entry: std::marker::Copy + serde::Deserialize<'de>,
{
    /// Deserializes the `(header, entries)` tuple produced by the `Serialize` impl, rejecting
    /// inputs whose header length disagrees with the number of entries.
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        // Visitor that rebuilds the wrapper from the serialized 2-tuple.
        struct WrapperVisitor<X> {
            marker: PhantomData<X>,
        }

        impl<'de, X: Default + FamStruct + Deserialize<'de>> Visitor<'de> for WrapperVisitor<X>
        where
            <X as FamStruct>::Entry: std::marker::Copy + serde::Deserialize<'de>,
        {
            type Value = FamStructWrapper<X>;

            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                formatter.write_str("FamStructWrapper")
            }

            fn visit_seq<V>(self, mut seq: V) -> Result<FamStructWrapper<X>, V::Error>
            where
                V: SeqAccess<'de>,
            {
                use serde::de::Error;

                let header: X = match seq.next_element()? {
                    Some(header) => header,
                    None => return Err(de::Error::invalid_length(0, &self)),
                };
                let entries: Vec<X::Entry> = match seq.next_element()? {
                    Some(entries) => entries,
                    None => return Err(de::Error::invalid_length(1, &self)),
                };

                let (header_len, entries_len) = (header.len(), entries.len());
                if header_len != entries_len {
                    return Err(V::Error::custom(format!(
                        "Mismatch between length of FAM specified in FamStruct header ({}) \
                         and actual size of FAM ({})",
                        header_len, entries_len
                    )));
                }

                // `from_entries` enforces `X::max_len()`; the deserialized header is then put
                // back so that its other fields are preserved.
                let mut wrapper: Self::Value = FamStructWrapper::from_entries(&entries)
                    .map_err(|err| V::Error::custom(format!("{:?}", err)))?;
                wrapper.mem_allocator[0] = header;
                Ok(wrapper)
            }
        }

        deserializer.deserialize_tuple(
            2,
            WrapperVisitor {
                marker: PhantomData,
            },
        )
    }
}
593 
/// Generate `FamStruct` implementation for structs with flexible array member.
///
/// # Arguments
///
/// * `$struct_type` - the struct containing the flexible array member.
/// * `$entry_type` - the type of the FAM entries (`FamStruct::Entry`).
/// * `$entries_name` - the name of the FAM field. It must provide `as_slice(len)` and
///   `as_mut_slice(len)` methods, like the bindgen-style `__IncompleteArrayField` helper.
/// * `$field_type` - the integer type of the field holding the FAM length.
/// * `$field_name` - the name of the field holding the FAM length.
/// * `$max` - the value returned by `FamStruct::max_len()`.
///
/// The generated impl refers to the trait through `$crate`, so call sites do not need to
/// import `FamStruct` themselves. `set_len` converts with `as`, so a length that does not fit
/// in `$field_type` is truncated; keep `$max` within the range of `$field_type`.
#[macro_export]
macro_rules! generate_fam_struct_impl {
    ($struct_type: ty, $entry_type: ty, $entries_name: ident,
     $field_type: ty, $field_name: ident, $max: expr) => {
        unsafe impl $crate::fam::FamStruct for $struct_type {
            type Entry = $entry_type;

            fn len(&self) -> usize {
                self.$field_name as usize
            }

            unsafe fn set_len(&mut self, len: usize) {
                self.$field_name = len as $field_type;
            }

            fn max_len() -> usize {
                $max
            }

            fn as_slice(&self) -> &[$entry_type] {
                // Read the length field directly: calling the trait method would require the
                // trait to be in scope at the macro call site.
                let len = self.$field_name as usize;
                // SAFETY: Implementers of `FamStruct` guarantee that the length field describes
                // the number of entries actually backing the flexible array member.
                unsafe { self.$entries_name.as_slice(len) }
            }

            fn as_mut_slice(&mut self) -> &mut [$entry_type] {
                let len = self.$field_name as usize;
                // SAFETY: Implementers of `FamStruct` guarantee that the length field describes
                // the number of entries actually backing the flexible array member.
                unsafe { self.$entries_name.as_mut_slice(len) }
            }
        }
    };
}
626 
627 #[cfg(test)]
628 mod tests {
629     #![allow(clippy::undocumented_unsafe_blocks)]
630 
631     #[cfg(feature = "with-serde")]
632     use serde_derive::{Deserialize, Serialize};
633 
634     use super::*;
635 
636     const MAX_LEN: usize = 100;
637 
638     #[repr(C)]
639     #[derive(Default, Debug, PartialEq, Eq)]
640     pub struct __IncompleteArrayField<T>(::std::marker::PhantomData<T>, [T; 0]);
641     impl<T> __IncompleteArrayField<T> {
642         #[inline]
new() -> Self643         pub fn new() -> Self {
644             __IncompleteArrayField(::std::marker::PhantomData, [])
645         }
646         #[inline]
as_ptr(&self) -> *const T647         pub unsafe fn as_ptr(&self) -> *const T {
648             self as *const __IncompleteArrayField<T> as *const T
649         }
650         #[inline]
as_mut_ptr(&mut self) -> *mut T651         pub unsafe fn as_mut_ptr(&mut self) -> *mut T {
652             self as *mut __IncompleteArrayField<T> as *mut T
653         }
654         #[inline]
as_slice(&self, len: usize) -> &[T]655         pub unsafe fn as_slice(&self, len: usize) -> &[T] {
656             ::std::slice::from_raw_parts(self.as_ptr(), len)
657         }
658         #[inline]
as_mut_slice(&mut self, len: usize) -> &mut [T]659         pub unsafe fn as_mut_slice(&mut self, len: usize) -> &mut [T] {
660             ::std::slice::from_raw_parts_mut(self.as_mut_ptr(), len)
661         }
662     }
663 
664     #[cfg(feature = "with-serde")]
665     impl<T> Serialize for __IncompleteArrayField<T> {
serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error> where S: Serializer,666         fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
667         where
668             S: Serializer,
669         {
670             [0u8; 0].serialize(serializer)
671         }
672     }
673 
674     #[cfg(feature = "with-serde")]
675     impl<'de, T> Deserialize<'de> for __IncompleteArrayField<T> {
deserialize<D>(_: D) -> std::result::Result<Self, D::Error> where D: Deserializer<'de>,676         fn deserialize<D>(_: D) -> std::result::Result<Self, D::Error>
677         where
678             D: Deserializer<'de>,
679         {
680             Ok(__IncompleteArrayField::new())
681         }
682     }
683 
684     #[repr(C)]
685     #[derive(Default, PartialEq)]
686     struct MockFamStruct {
687         pub len: u32,
688         pub padding: u32,
689         pub entries: __IncompleteArrayField<u32>,
690     }
691 
692     generate_fam_struct_impl!(MockFamStruct, u32, entries, u32, len, 100);
693 
694     type MockFamStructWrapper = FamStructWrapper<MockFamStruct>;
695 
696     const ENTRIES_OFFSET: usize = 2;
697 
698     const FAM_LEN_TO_MEM_ALLOCATOR_LEN: &[(usize, usize)] = &[
699         (0, 1),
700         (1, 2),
701         (2, 2),
702         (3, 3),
703         (4, 3),
704         (5, 4),
705         (10, 6),
706         (50, 26),
707         (100, 51),
708     ];
709 
710     const MEM_ALLOCATOR_LEN_TO_FAM_LEN: &[(usize, usize)] = &[
711         (0, 0),
712         (1, 0),
713         (2, 2),
714         (3, 4),
715         (4, 6),
716         (5, 8),
717         (10, 18),
718         (50, 98),
719         (100, 198),
720     ];
721 
722     #[test]
test_mem_allocator_len()723     fn test_mem_allocator_len() {
724         for pair in FAM_LEN_TO_MEM_ALLOCATOR_LEN {
725             let fam_len = pair.0;
726             let mem_allocator_len = pair.1;
727             assert_eq!(
728                 Some(mem_allocator_len),
729                 MockFamStructWrapper::mem_allocator_len(fam_len)
730             );
731         }
732     }
733 
734     #[repr(C)]
735     #[derive(Default, PartialEq)]
736     struct MockFamStructU8 {
737         pub len: u32,
738         pub padding: u32,
739         pub entries: __IncompleteArrayField<u8>,
740     }
741     generate_fam_struct_impl!(MockFamStructU8, u8, entries, u32, len, 100);
742     type MockFamStructWrapperU8 = FamStructWrapper<MockFamStructU8>;
743     #[test]
test_invalid_type_conversion()744     fn test_invalid_type_conversion() {
745         let mut adapter = MockFamStructWrapperU8::new(10).unwrap();
746         assert!(matches!(
747             adapter.set_len(0xffff_ffff_ffff_ff00),
748             Err(Error::SizeLimitExceeded)
749         ));
750     }
751 
752     #[test]
test_wrapper_len()753     fn test_wrapper_len() {
754         for pair in MEM_ALLOCATOR_LEN_TO_FAM_LEN {
755             let mem_allocator_len = pair.0;
756             let fam_len = pair.1;
757             assert_eq!(fam_len, MockFamStructWrapper::fam_len(mem_allocator_len));
758         }
759     }
760 
761     #[test]
test_new()762     fn test_new() {
763         let num_entries = 10;
764 
765         let adapter = MockFamStructWrapper::new(num_entries).unwrap();
766         assert_eq!(num_entries, adapter.capacity());
767 
768         let u32_slice = unsafe {
769             std::slice::from_raw_parts(
770                 adapter.as_fam_struct_ptr() as *const u32,
771                 num_entries + ENTRIES_OFFSET,
772             )
773         };
774         assert_eq!(num_entries, u32_slice[0] as usize);
775         for entry in u32_slice[1..].iter() {
776             assert_eq!(*entry, 0);
777         }
778 
779         // It's okay to create a `FamStructWrapper` with the maximum allowed number of entries.
780         let adapter = MockFamStructWrapper::new(MockFamStruct::max_len()).unwrap();
781         assert_eq!(MockFamStruct::max_len(), adapter.capacity());
782 
783         assert!(matches!(
784             MockFamStructWrapper::new(MockFamStruct::max_len() + 1),
785             Err(Error::SizeLimitExceeded)
786         ));
787     }
788 
789     #[test]
test_from_entries()790     fn test_from_entries() {
791         let num_entries: usize = 10;
792 
793         let mut entries = Vec::new();
794         for i in 0..num_entries {
795             entries.push(i as u32);
796         }
797 
798         let adapter = MockFamStructWrapper::from_entries(entries.as_slice()).unwrap();
799         let u32_slice = unsafe {
800             std::slice::from_raw_parts(
801                 adapter.as_fam_struct_ptr() as *const u32,
802                 num_entries + ENTRIES_OFFSET,
803             )
804         };
805         assert_eq!(num_entries, u32_slice[0] as usize);
806         for (i, &value) in entries.iter().enumerate().take(num_entries) {
807             assert_eq!(adapter.as_slice()[i], value);
808         }
809 
810         let mut entries = Vec::new();
811         for i in 0..MockFamStruct::max_len() + 1 {
812             entries.push(i as u32);
813         }
814 
815         // Can't create a `FamStructWrapper` with a number of entries > MockFamStruct::max_len().
816         assert!(matches!(
817             MockFamStructWrapper::from_entries(entries.as_slice()),
818             Err(Error::SizeLimitExceeded)
819         ));
820     }
821 
822     #[test]
test_entries_slice()823     fn test_entries_slice() {
824         let num_entries = 10;
825         let mut adapter = MockFamStructWrapper::new(num_entries).unwrap();
826 
827         let expected_slice = &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
828 
829         {
830             let mut_entries_slice = adapter.as_mut_slice();
831             mut_entries_slice.copy_from_slice(expected_slice);
832         }
833 
834         let u32_slice = unsafe {
835             std::slice::from_raw_parts(
836                 adapter.as_fam_struct_ptr() as *const u32,
837                 num_entries + ENTRIES_OFFSET,
838             )
839         };
840         assert_eq!(expected_slice, &u32_slice[ENTRIES_OFFSET..]);
841         assert_eq!(expected_slice, adapter.as_slice());
842     }
843 
844     #[test]
test_reserve()845     fn test_reserve() {
846         let mut adapter = MockFamStructWrapper::new(0).unwrap();
847 
848         // test that the right capacity is reserved
849         for pair in FAM_LEN_TO_MEM_ALLOCATOR_LEN {
850             let num_elements = pair.0;
851             let required_mem_allocator_len = pair.1;
852 
853             adapter.reserve(num_elements).unwrap();
854 
855             assert!(adapter.mem_allocator.capacity() >= required_mem_allocator_len);
856             assert_eq!(0, adapter.len());
857             assert!(adapter.capacity() >= num_elements);
858         }
859 
860         // test that when the capacity is already reserved, the method doesn't do anything
861         let current_capacity = adapter.capacity();
862         adapter.reserve(current_capacity - 1).unwrap();
863         assert_eq!(current_capacity, adapter.capacity());
864     }
865 
866     #[test]
test_set_len()867     fn test_set_len() {
868         let mut desired_len = 0;
869         let mut adapter = MockFamStructWrapper::new(desired_len).unwrap();
870 
871         // keep initial len
872         assert!(adapter.set_len(desired_len).is_ok());
873         assert_eq!(adapter.len(), desired_len);
874 
875         // increase len
876         desired_len = 10;
877         assert!(adapter.set_len(desired_len).is_ok());
878         // check that the len has been increased and zero-initialized elements have been added
879         assert_eq!(adapter.len(), desired_len);
880         for element in adapter.as_slice() {
881             assert_eq!(*element, 0_u32);
882         }
883 
884         // decrease len
885         desired_len = 5;
886         assert!(adapter.set_len(desired_len).is_ok());
887         assert_eq!(adapter.len(), desired_len);
888     }
889 
890     #[test]
test_push()891     fn test_push() {
892         let mut adapter = MockFamStructWrapper::new(0).unwrap();
893 
894         for i in 0..MAX_LEN {
895             assert!(adapter.push(i as u32).is_ok());
896             assert_eq!(adapter.as_slice()[i], i as u32);
897             assert_eq!(adapter.len(), i + 1);
898             assert!(
899                 adapter.mem_allocator.capacity()
900                     >= MockFamStructWrapper::mem_allocator_len(i + 1).unwrap()
901             );
902         }
903 
904         assert!(adapter.push(0).is_err());
905     }
906 
907     #[test]
test_retain()908     fn test_retain() {
909         let mut adapter = MockFamStructWrapper::new(0).unwrap();
910 
911         let mut num_retained_entries = 0;
912         for i in 0..MAX_LEN {
913             assert!(adapter.push(i as u32).is_ok());
914             if i % 2 == 0 {
915                 num_retained_entries += 1;
916             }
917         }
918 
919         adapter.retain(|entry| entry % 2 == 0);
920 
921         for entry in adapter.as_slice().iter() {
922             assert_eq!(0, entry % 2);
923         }
924         assert_eq!(adapter.len(), num_retained_entries);
925         assert!(
926             adapter.mem_allocator.capacity()
927                 >= MockFamStructWrapper::mem_allocator_len(num_retained_entries).unwrap()
928         );
929     }
930 
931     #[test]
test_partial_eq()932     fn test_partial_eq() {
933         let mut wrapper_1 = MockFamStructWrapper::new(0).unwrap();
934         let mut wrapper_2 = MockFamStructWrapper::new(0).unwrap();
935         let mut wrapper_3 = MockFamStructWrapper::new(0).unwrap();
936 
937         for i in 0..MAX_LEN {
938             assert!(wrapper_1.push(i as u32).is_ok());
939             assert!(wrapper_2.push(i as u32).is_ok());
940             assert!(wrapper_3.push(0).is_ok());
941         }
942 
943         assert!(wrapper_1 == wrapper_2);
944         assert!(wrapper_1 != wrapper_3);
945     }
946 
947     #[test]
test_clone()948     fn test_clone() {
949         let mut adapter = MockFamStructWrapper::new(0).unwrap();
950 
951         for i in 0..MAX_LEN {
952             assert!(adapter.push(i as u32).is_ok());
953         }
954 
955         assert!(adapter == adapter.clone());
956     }
957 
958     #[test]
test_raw_content()959     fn test_raw_content() {
960         let data = vec![
961             MockFamStruct {
962                 len: 2,
963                 padding: 5,
964                 entries: __IncompleteArrayField::new(),
965             },
966             MockFamStruct {
967                 len: 0xA5,
968                 padding: 0x1e,
969                 entries: __IncompleteArrayField::new(),
970             },
971         ];
972 
973         let mut wrapper = unsafe { MockFamStructWrapper::from_raw(data) };
974         {
975             let payload = wrapper.as_slice();
976             assert_eq!(payload[0], 0xA5);
977             assert_eq!(payload[1], 0x1e);
978         }
979         assert_eq!(unsafe { wrapper.as_mut_fam_struct() }.padding, 5);
980         let data = wrapper.into_raw();
981         assert_eq!(data[0].len, 2);
982         assert_eq!(data[0].padding, 5);
983     }
984 
985     #[cfg(feature = "with-serde")]
986     #[test]
test_ser_deser()987     fn test_ser_deser() {
988         #[repr(C)]
989         #[derive(Default, PartialEq)]
990         #[cfg_attr(feature = "with-serde", derive(Deserialize, Serialize))]
991         struct Message {
992             pub len: u32,
993             pub padding: u32,
994             pub value: u32,
995             #[cfg_attr(feature = "with-serde", serde(skip))]
996             pub entries: __IncompleteArrayField<u32>,
997         }
998 
999         generate_fam_struct_impl!(Message, u32, entries, u32, len, 100);
1000 
1001         type MessageFamStructWrapper = FamStructWrapper<Message>;
1002 
1003         let data = vec![
1004             Message {
1005                 len: 2,
1006                 padding: 0,
1007                 value: 42,
1008                 entries: __IncompleteArrayField::new(),
1009             },
1010             Message {
1011                 len: 0xA5,
1012                 padding: 0x1e,
1013                 value: 0,
1014                 entries: __IncompleteArrayField::new(),
1015             },
1016         ];
1017 
1018         let wrapper = unsafe { MessageFamStructWrapper::from_raw(data) };
1019         let data_ser = serde_json::to_string(&wrapper).unwrap();
1020         assert_eq!(
1021             data_ser,
1022             "[{\"len\":2,\"padding\":0,\"value\":42},[165,30]]"
1023         );
1024         let data_deser =
1025             serde_json::from_str::<MessageFamStructWrapper>(data_ser.as_str()).unwrap();
1026         assert!(wrapper.eq(&data_deser));
1027 
1028         let bad_data_ser = r#"{"foo": "bar"}"#;
1029         assert!(serde_json::from_str::<MessageFamStructWrapper>(bad_data_ser).is_err());
1030 
1031         #[repr(C)]
1032         #[derive(Default)]
1033         #[cfg_attr(feature = "with-serde", derive(Deserialize, Serialize))]
1034         struct Message2 {
1035             pub len: u32,
1036             pub padding: u32,
1037             pub value: u32,
1038             #[cfg_attr(feature = "with-serde", serde(skip))]
1039             pub entries: __IncompleteArrayField<u32>,
1040         }
1041 
1042         // Maximum number of entries = 1, so the deserialization should fail because of this reason.
1043         generate_fam_struct_impl!(Message2, u32, entries, u32, len, 1);
1044 
1045         type Message2FamStructWrapper = FamStructWrapper<Message2>;
1046         assert!(serde_json::from_str::<Message2FamStructWrapper>(data_ser.as_str()).is_err());
1047     }
1048 
1049     #[test]
test_clone_multiple_fields()1050     fn test_clone_multiple_fields() {
1051         #[derive(Default, PartialEq)]
1052         #[repr(C)]
1053         struct Foo {
1054             index: u32,
1055             length: u16,
1056             flags: u32,
1057             entries: __IncompleteArrayField<u32>,
1058         }
1059 
1060         generate_fam_struct_impl!(Foo, u32, entries, u16, length, 100);
1061 
1062         type FooFamStructWrapper = FamStructWrapper<Foo>;
1063 
1064         let mut wrapper = FooFamStructWrapper::new(0).unwrap();
1065         // SAFETY: We do play with length here, but that's just for testing purposes :)
1066         unsafe {
1067             wrapper.as_mut_fam_struct().index = 1;
1068             wrapper.as_mut_fam_struct().flags = 2;
1069             wrapper.as_mut_fam_struct().length = 3;
1070             wrapper.push(3).unwrap();
1071             wrapper.push(14).unwrap();
1072             assert_eq!(wrapper.as_slice().len(), 3 + 2);
1073             assert_eq!(wrapper.as_slice()[3], 3);
1074             assert_eq!(wrapper.as_slice()[3 + 1], 14);
1075 
1076             let mut wrapper2 = wrapper.clone();
1077             assert_eq!(
1078                 wrapper.as_mut_fam_struct().index,
1079                 wrapper2.as_mut_fam_struct().index
1080             );
1081             assert_eq!(
1082                 wrapper.as_mut_fam_struct().length,
1083                 wrapper2.as_mut_fam_struct().length
1084             );
1085             assert_eq!(
1086                 wrapper.as_mut_fam_struct().flags,
1087                 wrapper2.as_mut_fam_struct().flags
1088             );
1089             assert_eq!(wrapper.as_slice(), wrapper2.as_slice());
1090             assert_eq!(
1091                 wrapper2.as_slice().len(),
1092                 wrapper2.as_mut_fam_struct().length as usize
1093             );
1094             assert!(wrapper == wrapper2);
1095 
1096             wrapper.as_mut_fam_struct().index = 3;
1097             assert!(wrapper != wrapper2);
1098 
1099             wrapper.as_mut_fam_struct().length = 7;
1100             assert!(wrapper != wrapper2);
1101 
1102             wrapper.push(1).unwrap();
1103             assert_eq!(wrapper.as_mut_fam_struct().length, 8);
1104             assert!(wrapper != wrapper2);
1105 
1106             let mut wrapper2 = wrapper.clone();
1107             assert!(wrapper == wrapper2);
1108 
1109             // Dropping the original variable should not affect its clone.
1110             drop(wrapper);
1111             assert_eq!(wrapper2.as_mut_fam_struct().index, 3);
1112             assert_eq!(wrapper2.as_mut_fam_struct().length, 8);
1113             assert_eq!(wrapper2.as_mut_fam_struct().flags, 2);
1114             assert_eq!(wrapper2.as_slice(), [0, 0, 0, 3, 14, 0, 0, 1]);
1115         }
1116     }
1117 
1118     #[cfg(feature = "with-serde")]
1119     #[test]
test_bad_deserialize()1120     fn test_bad_deserialize() {
1121         #[repr(C)]
1122         #[derive(Default, Debug, PartialEq, Serialize, Deserialize)]
1123         struct Foo {
1124             pub len: u32,
1125             pub padding: u32,
1126             pub entries: __IncompleteArrayField<u32>,
1127         }
1128 
1129         generate_fam_struct_impl!(Foo, u32, entries, u32, len, 100);
1130 
1131         let state = FamStructWrapper::<Foo>::new(0).unwrap();
1132         let mut bytes = bincode::serialize(&state).unwrap();
1133 
1134         // The `len` field of the header is the first to be serialized.
1135         // Writing at position 0 of the serialized data should change its value.
1136         bytes[0] = 255;
1137 
1138         assert!(
1139             matches!(bincode::deserialize::<FamStructWrapper<Foo>>(&bytes).map_err(|boxed| *boxed), Err(bincode::ErrorKind::Custom(s)) if s == *"Mismatch between length of FAM specified in FamStruct header (255) and actual size of FAM (0)")
1140         );
1141     }
1142 }
1143