1 // Copyright 2023 Google LLC
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 use std::convert::TryFrom;
16 use std::convert::TryInto;
17 use std::ops::Deref;
18 
/// Errors produced while decoding packets from a bit slice.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ParseError {
    /// Not raised in this file; presumably returned by generated parsers when a
    /// decoded value matches no enum variant — confirm.
    InvalidEnumValue,
    /// Not raised in this file; presumably returned by generated code when a size or
    /// count division fails — confirm.
    DivisionFailure,
    /// An offset computation overflowed `usize`, or a decoded value did not fit the
    /// requested integer type.
    ArithmeticOverflow,
    /// A sub-slice or read fell outside the available bits or the backing buffer.
    OutOfBoundsAccess,
    /// Returned by `SizedBitSlice::try_parse` when the field is wider than 64 bits.
    MisalignedPayload,
}

impl std::fmt::Display for ParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            ParseError::InvalidEnumValue => "invalid enum value",
            ParseError::DivisionFailure => "division failure",
            ParseError::ArithmeticOverflow => "arithmetic overflow",
            ParseError::OutOfBoundsAccess => "out-of-bounds access",
            ParseError::MisalignedPayload => "misaligned payload",
        };
        f.write_str(message)
    }
}

impl std::error::Error for ParseError {}
27 
/// A read-only window of bits, `[start_bit_offset, end_bit_offset)`, measured from the
/// first bit of a borrowed byte buffer.
#[derive(Debug, Copy, Clone)]
pub struct BitSlice<'a> {
    // The underlying bytes. Nothing ties the bit offsets below to `backing.len()`,
    // so every byte access must be bounds-checked instead of indexed directly.
    backing: &'a [u8],

    // Invariant: start_bit_offset <= end_bit_offset, so `end - start` cannot wrap.
    start_bit_offset: usize,
    end_bit_offset: usize,
}
38 
39 #[derive(Clone, Copy, Debug)]
40 pub struct SizedBitSlice<'a>(BitSlice<'a>);
41 
42 impl<'a> BitSlice<'a> {
offset(&self, offset: usize) -> Result<BitSlice<'a>, ParseError>43     pub fn offset(&self, offset: usize) -> Result<BitSlice<'a>, ParseError> {
44         if self.end_bit_offset - self.start_bit_offset < offset {
45             return Err(ParseError::OutOfBoundsAccess);
46         }
47         Ok(Self {
48             backing: self.backing,
49             start_bit_offset: self
50                 .start_bit_offset
51                 .checked_add(offset)
52                 .ok_or(ParseError::ArithmeticOverflow)?,
53             end_bit_offset: self.end_bit_offset,
54         })
55     }
56 
slice(&self, len: usize) -> Result<SizedBitSlice<'a>, ParseError>57     pub fn slice(&self, len: usize) -> Result<SizedBitSlice<'a>, ParseError> {
58         if self.end_bit_offset - self.start_bit_offset < len {
59             return Err(ParseError::OutOfBoundsAccess);
60         }
61         Ok(SizedBitSlice(Self {
62             backing: self.backing,
63             start_bit_offset: self.start_bit_offset,
64             end_bit_offset: self
65                 .start_bit_offset
66                 .checked_add(len)
67                 .ok_or(ParseError::ArithmeticOverflow)?,
68         }))
69     }
70 
byte_at(&self, index: usize) -> Result<u8, ParseError>71     fn byte_at(&self, index: usize) -> Result<u8, ParseError> {
72         self.backing.get(index).ok_or(ParseError::OutOfBoundsAccess).copied()
73     }
74 }
75 
76 impl<'a> Deref for SizedBitSlice<'a> {
77     type Target = BitSlice<'a>;
78 
deref(&self) -> &Self::Target79     fn deref(&self) -> &Self::Target {
80         &self.0
81     }
82 }
83 
84 impl<'a> From<SizedBitSlice<'a>> for BitSlice<'a> {
from(x: SizedBitSlice<'a>) -> Self85     fn from(x: SizedBitSlice<'a>) -> Self {
86         *x
87     }
88 }
89 
90 impl<'a, 'b> From<&'b [u8]> for SizedBitSlice<'a>
91 where
92     'b: 'a,
93 {
from(backing: &'a [u8]) -> Self94     fn from(backing: &'a [u8]) -> Self {
95         Self(BitSlice { backing, start_bit_offset: 0, end_bit_offset: backing.len() * 8 })
96     }
97 }
98 
99 impl<'a> SizedBitSlice<'a> {
try_parse<T: TryFrom<u64>>(&self) -> Result<T, ParseError>100     pub fn try_parse<T: TryFrom<u64>>(&self) -> Result<T, ParseError> {
101         if self.end_bit_offset < self.start_bit_offset {
102             return Err(ParseError::OutOfBoundsAccess);
103         }
104         let size_in_bits = self.end_bit_offset - self.start_bit_offset;
105 
106         // fields that fit into a u64 don't need to be byte-aligned
107         if size_in_bits <= 64 {
108             let mut accumulator = 0u64;
109 
110             // where we are in our accumulation
111             let mut curr_byte_index = self.start_bit_offset / 8;
112             let mut curr_bit_offset = self.start_bit_offset % 8;
113             let mut remaining_bits = size_in_bits;
114 
115             while remaining_bits > 0 {
116                 // how many bits to take from the current byte?
117                 // check if this is the last byte
118                 if curr_bit_offset + remaining_bits <= 8 {
119                     let tmp = ((self.byte_at(curr_byte_index)? >> curr_bit_offset) as u64)
120                         & ((1u64 << remaining_bits) - 1);
121                     accumulator += tmp << (size_in_bits - remaining_bits);
122                     break;
123                 } else {
124                     // this is not the last byte, so we have 8 - curr_bit_offset bits to
125                     // consume in this byte
126                     let bits_to_consume = 8 - curr_bit_offset;
127                     let tmp = (self.byte_at(curr_byte_index)? >> curr_bit_offset) as u64;
128                     accumulator += tmp << (size_in_bits - remaining_bits);
129                     curr_bit_offset = 0;
130                     curr_byte_index += 1;
131                     remaining_bits -= bits_to_consume as usize;
132                 }
133             }
134             T::try_from(accumulator).map_err(|_| ParseError::ArithmeticOverflow)
135         } else {
136             return Err(ParseError::MisalignedPayload);
137         }
138     }
139 
get_size_in_bits(&self) -> usize140     pub fn get_size_in_bits(&self) -> usize {
141         self.end_bit_offset - self.start_bit_offset
142     }
143 }
144 
145 pub trait Packet<'a>
146 where
147     Self: Sized,
148 {
149     type Parent;
150     type Owned;
151     type Builder;
try_parse_from_buffer(buf: impl Into<SizedBitSlice<'a>>) -> Result<Self, ParseError>152     fn try_parse_from_buffer(buf: impl Into<SizedBitSlice<'a>>) -> Result<Self, ParseError>;
try_parse(parent: Self::Parent) -> Result<Self, ParseError>153     fn try_parse(parent: Self::Parent) -> Result<Self, ParseError>;
to_owned_packet(&self) -> Self::Owned154     fn to_owned_packet(&self) -> Self::Owned;
155 }
156 
157 pub trait OwnedPacket
158 where
159     Self: Sized,
160 {
161     // Enable GAT when 1.65 is available in AOSP
162     // type View<'a> where Self : 'a;
try_parse(buf: Box<[u8]>) -> Result<Self, ParseError>163     fn try_parse(buf: Box<[u8]>) -> Result<Self, ParseError>;
164     // fn view<'a>(&'a self) -> Self::View<'a>;
165 }
166 
167 pub trait Builder: Serializable {
168     type OwnedPacket: OwnedPacket;
169 }
170 
/// Errors produced while serializing packets into bits.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SerializeError {
    /// Not raised in this file; presumably returned by generated serializers when a
    /// computed padding length is negative — confirm.
    NegativePadding,
    /// Not raised in this file; presumably returned when a value cannot be converted
    /// to the field's integer type — confirm.
    IntegerConversionFailure,
    /// A value does not fit in the number of bits allotted to its field.
    ValueTooLarge,
    /// Not raised in this file; presumably an alignment violation detected by
    /// generated serializers — confirm.
    AlignmentError,
}

impl std::fmt::Display for SerializeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            SerializeError::NegativePadding => "negative padding",
            SerializeError::IntegerConversionFailure => "integer conversion failure",
            SerializeError::ValueTooLarge => "value too large for field",
            SerializeError::AlignmentError => "alignment error",
        };
        f.write_str(message)
    }
}

impl std::error::Error for SerializeError {}
178 
179 pub trait BitWriter {
write_bits<T: Into<u64>>( &mut self, num_bits: usize, gen_contents: impl FnOnce() -> Result<T, SerializeError>, ) -> Result<(), SerializeError>180     fn write_bits<T: Into<u64>>(
181         &mut self,
182         num_bits: usize,
183         gen_contents: impl FnOnce() -> Result<T, SerializeError>,
184     ) -> Result<(), SerializeError>;
185 }
186 
187 pub trait Serializable {
serialize(&self, writer: &mut impl BitWriter) -> Result<(), SerializeError>188     fn serialize(&self, writer: &mut impl BitWriter) -> Result<(), SerializeError>;
189 
size_in_bits(&self) -> Result<usize, SerializeError>190     fn size_in_bits(&self) -> Result<usize, SerializeError> {
191         let mut sizer = Sizer::new();
192         self.serialize(&mut sizer)?;
193         Ok(sizer.size())
194     }
195 
write(&self, vec: &mut Vec<u8>) -> Result<(), SerializeError>196     fn write(&self, vec: &mut Vec<u8>) -> Result<(), SerializeError> {
197         let mut serializer = Serializer::new(vec);
198         self.serialize(&mut serializer)?;
199         serializer.flush();
200         Ok(())
201     }
202 
to_vec(&self) -> Result<Vec<u8>, SerializeError>203     fn to_vec(&self) -> Result<Vec<u8>, SerializeError> {
204         let mut out = vec![];
205         self.write(&mut out)?;
206         Ok(out)
207     }
208 }
209 
/// Bit writer that only tallies how many bits would be written.
struct Sizer {
    // Running total of bits reported so far.
    size: usize,
}

impl Sizer {
    /// Creates a sizer whose running total starts at zero.
    fn new() -> Self {
        Sizer { size: 0 }
    }

    /// Consumes the sizer and returns the accumulated number of bits.
    fn size(self) -> usize {
        let Sizer { size } = self;
        size
    }
}
223 
224 impl BitWriter for Sizer {
write_bits<T: Into<u64>>( &mut self, num_bits: usize, gen_contents: impl FnOnce() -> Result<T, SerializeError>, ) -> Result<(), SerializeError>225     fn write_bits<T: Into<u64>>(
226         &mut self,
227         num_bits: usize,
228         gen_contents: impl FnOnce() -> Result<T, SerializeError>,
229     ) -> Result<(), SerializeError> {
230         self.size += num_bits;
231         Ok(())
232     }
233 }
234 
/// Bit writer that packs fields least-significant-bit first into a byte vector.
struct Serializer<'a> {
    // Destination; complete bytes are pushed here as soon as they are filled.
    buf: &'a mut Vec<u8>,
    // Byte currently being assembled; its low `curr_bit_offset` bits are filled.
    curr_byte: u8,
    // Number of bits of `curr_byte` already filled, always in 0..8.
    curr_bit_offset: u8,
}

impl<'a> Serializer<'a> {
    /// Creates a serializer that appends to `buf`, starting on a byte boundary.
    fn new(buf: &'a mut Vec<u8>) -> Self {
        Self { buf, curr_byte: 0, curr_bit_offset: 0 }
    }

    /// Emits the pending partial byte, if any, with its unused high bits set to zero.
    ///
    /// `write_bits` fills each byte starting from the least-significant bit, so the
    /// pending bits already sit in the low positions of `curr_byte` and must be pushed
    /// unchanged. Shifting them toward the high end would misplace them relative to
    /// both a later full byte and the LSB-first reader in `SizedBitSlice::try_parse`.
    fn flush(self) {
        if self.curr_bit_offset > 0 {
            self.buf.push(self.curr_byte);
        }
    }
}
253 
254 impl<'a> BitWriter for Serializer<'a> {
write_bits<T: Into<u64>>( &mut self, num_bits: usize, gen_contents: impl FnOnce() -> Result<T, SerializeError>, ) -> Result<(), SerializeError>255     fn write_bits<T: Into<u64>>(
256         &mut self,
257         num_bits: usize,
258         gen_contents: impl FnOnce() -> Result<T, SerializeError>,
259     ) -> Result<(), SerializeError> {
260         let val = gen_contents()?.into();
261 
262         if num_bits < 64 && val >= 1 << num_bits {
263             return Err(SerializeError::ValueTooLarge);
264         }
265 
266         let mut remaining_val = val;
267         let mut remaining_bits = num_bits;
268         while remaining_bits > 0 {
269             let remaining_bits_in_curr_byte = (8 - self.curr_bit_offset) as usize;
270             if remaining_bits < remaining_bits_in_curr_byte {
271                 // we cannot finish the last byte
272                 self.curr_byte += (remaining_val as u8) << self.curr_bit_offset;
273                 self.curr_bit_offset += remaining_bits as u8;
274                 break;
275             } else {
276                 // finish up our current byte and move on
277                 let val_for_this_byte =
278                     (remaining_val & ((1 << remaining_bits_in_curr_byte) - 1)) as u8;
279                 let curr_byte = self.curr_byte + (val_for_this_byte << self.curr_bit_offset);
280                 self.buf.push(curr_byte);
281 
282                 // clear pending byte
283                 self.curr_bit_offset = 0;
284                 self.curr_byte = 0;
285 
286                 // update what's remaining
287                 remaining_val >>= remaining_bits_in_curr_byte;
288                 remaining_bits -= remaining_bits_in_curr_byte;
289             }
290         }
291 
292         Ok(())
293     }
294 }
295