// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::convert::TryFrom;
use std::convert::TryInto;
use std::ops::Deref;

/// Errors reported while reading fields out of a [`BitSlice`].
#[derive(Debug)]
pub enum ParseError {
    InvalidEnumValue,
    DivisionFailure,
    ArithmeticOverflow,
    OutOfBoundsAccess,
    MisalignedPayload,
}

/// A bit-granular window `[start_bit_offset, end_bit_offset)` over a borrowed byte buffer.
///
/// Bits are numbered LSB-first: bit offset `n` is bit `n % 8` (counting from the least
/// significant bit) of byte `n / 8`.
#[derive(Clone, Copy, Debug)]
pub struct BitSlice<'a> {
    // note: the offsets are NOT validated against `backing.len()`,
    // so every byte access goes through `byte_at` to avoid panics
    backing: &'a [u8],

    // invariant: end_bit_offset >= start_bit_offset, so subtraction will NEVER wrap
    start_bit_offset: usize,
    end_bit_offset: usize,
}

/// A [`BitSlice`] whose end offset has been fixed to the extent of one field or payload.
///
/// Produced by [`BitSlice::slice`] or from a whole byte buffer; only sized slices can be
/// decoded with [`SizedBitSlice::try_parse`].
#[derive(Clone, Copy, Debug)]
pub struct SizedBitSlice<'a>(BitSlice<'a>);

impl<'a> BitSlice<'a> {
    /// Returns a view that starts `offset` bits later and keeps the same end.
    ///
    /// # Errors
    /// `OutOfBoundsAccess` if `offset` exceeds the number of bits left in this view.
    pub fn offset(&self, offset: usize) -> Result<BitSlice<'a>, ParseError> {
        if self.end_bit_offset - self.start_bit_offset < offset {
            return Err(ParseError::OutOfBoundsAccess);
        }
        Ok(Self {
            backing: self.backing,
            start_bit_offset: self
                .start_bit_offset
                .checked_add(offset)
                .ok_or(ParseError::ArithmeticOverflow)?,
            end_bit_offset: self.end_bit_offset,
        })
    }

    /// Returns the first `len` bits of this view as a [`SizedBitSlice`].
    ///
    /// # Errors
    /// `OutOfBoundsAccess` if `len` exceeds the number of bits left in this view.
    pub fn slice(&self, len: usize) -> Result<SizedBitSlice<'a>, ParseError> {
        if self.end_bit_offset - self.start_bit_offset < len {
            return Err(ParseError::OutOfBoundsAccess);
        }
        Ok(SizedBitSlice(Self {
            backing: self.backing,
            start_bit_offset: self.start_bit_offset,
            end_bit_offset: self
                .start_bit_offset
                .checked_add(len)
                .ok_or(ParseError::ArithmeticOverflow)?,
        }))
    }

    /// Bounds-checked read of byte `index` of the backing buffer.
    fn byte_at(&self, index: usize) -> Result<u8, ParseError> {
        self.backing.get(index).copied().ok_or(ParseError::OutOfBoundsAccess)
    }
}

impl<'a> Deref for SizedBitSlice<'a> {
    type Target = BitSlice<'a>;

    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl<'a> From<SizedBitSlice<'a>> for BitSlice<'a> {
    fn from(x: SizedBitSlice<'a>) -> Self {
        *x
    }
}

impl<'a, 'b> From<&'b [u8]> for SizedBitSlice<'a>
where
    'b: 'a,
{
    /// Views the whole buffer, from bit 0 to `8 * backing.len()`.
    fn from(backing: &'b [u8]) -> Self {
        Self(BitSlice { backing, start_bit_offset: 0, end_bit_offset: backing.len() * 8 })
    }
}

impl<'a> SizedBitSlice<'a> {
    /// Decodes the bits of this slice as an LSB-first unsigned integer and converts it to `T`.
    ///
    /// Fields of at most 64 bits may start and end at any bit offset.
    ///
    /// # Errors
    /// - `MisalignedPayload` if the slice is wider than 64 bits.
    /// - `OutOfBoundsAccess` if the slice extends past the backing buffer.
    /// - `ArithmeticOverflow` if the decoded value does not fit in `T`.
    pub fn try_parse<T: TryFrom<u64>>(&self) -> Result<T, ParseError> {
        // Defensive: the struct invariant should make this unreachable.
        if self.end_bit_offset < self.start_bit_offset {
            return Err(ParseError::OutOfBoundsAccess);
        }
        let size_in_bits = self.end_bit_offset - self.start_bit_offset;

        // Only fields that fit into a u64 are decoded here; they need not be byte-aligned.
        if size_in_bits > 64 {
            return Err(ParseError::MisalignedPayload);
        }

        let mut accumulator = 0u64;
        let mut curr_byte_index = self.start_bit_offset / 8;
        let mut curr_bit_offset = self.start_bit_offset % 8;
        let mut remaining_bits = size_in_bits;

        while remaining_bits > 0 {
            // Bits already placed in `accumulator`; the next chunk goes right above them.
            let consumed_bits = size_in_bits - remaining_bits;
            let byte = u64::from(self.byte_at(curr_byte_index)? >> curr_bit_offset);

            if curr_bit_offset + remaining_bits <= 8 {
                // Last byte: keep only the bits that belong to this field.
                accumulator |= (byte & ((1u64 << remaining_bits) - 1)) << consumed_bits;
                break;
            }

            // Not the last byte: every bit above `curr_bit_offset` belongs to the field.
            accumulator |= byte << consumed_bits;
            remaining_bits -= 8 - curr_bit_offset;
            curr_bit_offset = 0;
            curr_byte_index += 1;
        }

        T::try_from(accumulator).map_err(|_| ParseError::ArithmeticOverflow)
    }

    /// Number of bits covered by this slice.
    pub fn get_size_in_bits(&self) -> usize {
        self.end_bit_offset - self.start_bit_offset
    }
}

/// A borrowed packet view over a [`SizedBitSlice`].
///
/// NOTE(review): the associated types and methods are only declared here; their exact
/// contracts are defined by implementors elsewhere.
pub trait Packet<'a>
where
    Self: Sized,
{
    type Parent;
    type Owned;
    type Builder;
    fn try_parse_from_buffer(buf: impl Into<SizedBitSlice<'a>>) -> Result<Self, ParseError>;
    fn try_parse(parent: Self::Parent) -> Result<Self, ParseError>;
    fn to_owned_packet(&self) -> Self::Owned;
}

/// A packet that owns its backing buffer.
pub trait OwnedPacket
where
    Self: Sized,
{
    // Enable GAT when 1.65 is available in AOSP
    // type View<'a> where Self : 'a;
    fn try_parse(buf: Box<[u8]>) -> Result<Self, ParseError>;
    // fn view<'a>(&'a self) -> Self::View<'a>;
}

/// A serializable packet description tied to the [`OwnedPacket`] type it produces.
pub trait Builder: Serializable {
    type OwnedPacket: OwnedPacket;
}

/// Errors reported while serializing a packet.
#[derive(Debug)]
pub enum SerializeError {
    NegativePadding,
    IntegerConversionFailure,
    ValueTooLarge,
    AlignmentError,
}

/// A sink that accepts values one bit-field at a time.
pub trait BitWriter {
    /// Appends the low `num_bits` bits of the value produced by `gen_contents`.
    ///
    /// The closure lets writers that only measure (such as size computation) skip
    /// producing the value.
    fn write_bits<T: Into<u64>>(
        &mut self,
        num_bits: usize,
        gen_contents: impl FnOnce() -> Result<T, SerializeError>,
    ) -> Result<(), SerializeError>;
}

/// Types that can emit themselves through a [`BitWriter`].
pub trait Serializable {
    /// Writes every field of `self` to `writer`, in order.
    fn serialize(&self, writer: &mut impl BitWriter) -> Result<(), SerializeError>;

    /// Total number of bits `serialize` writes.
    ///
    /// Field values are not generated while sizing, so this does not report
    /// `ValueTooLarge`.
    fn size_in_bits(&self) -> Result<usize, SerializeError> {
        let mut sizer = Sizer::new();
        self.serialize(&mut sizer)?;
        Ok(sizer.size())
    }

    /// Appends the serialized bytes to `vec`, emitting any trailing partial byte
    /// with its unused high bits set to zero.
    fn write(&self, vec: &mut Vec<u8>) -> Result<(), SerializeError> {
        let mut serializer = Serializer::new(vec);
        self.serialize(&mut serializer)?;
        serializer.flush();
        Ok(())
    }

    /// Serializes into a freshly allocated byte vector.
    fn to_vec(&self) -> Result<Vec<u8>, SerializeError> {
        let mut out = vec![];
        self.write(&mut out)?;
        Ok(out)
    }
}

/// A [`BitWriter`] that only counts bits.
struct Sizer {
    size: usize,
}

impl Sizer {
    fn new() -> Self {
        Self { size: 0 }
    }

    fn size(self) -> usize {
        self.size
    }
}

impl BitWriter for Sizer {
    fn write_bits<T: Into<u64>>(
        &mut self,
        num_bits: usize,
        _gen_contents: impl FnOnce() -> Result<T, SerializeError>,
    ) -> Result<(), SerializeError> {
        // Sizing never needs the value, so the generator is intentionally not invoked.
        self.size += num_bits;
        Ok(())
    }
}

/// A [`BitWriter`] that packs bits LSB-first into a byte vector.
struct Serializer<'a> {
    buf: &'a mut Vec<u8>,
    // Bits written so far into the byte not yet pushed, stored in its low
    // `curr_bit_offset` positions.
    curr_byte: u8,
    curr_bit_offset: u8,
}

impl<'a> Serializer<'a> {
    fn new(buf: &'a mut Vec<u8>) -> Self {
        Self { buf, curr_byte: 0, curr_bit_offset: 0 }
    }

    /// Pushes the pending partial byte, if any.
    fn flush(self) {
        if self.curr_bit_offset > 0 {
            // Bits are packed LSB-first, so the pending bits already occupy the low
            // positions of `curr_byte`; shifting them would make the parser (which also
            // reads LSB-first) see zeros instead of the written field.
            self.buf.push(self.curr_byte);
        }
    }
}

impl<'a> BitWriter for Serializer<'a> {
    fn write_bits<T: Into<u64>>(
        &mut self,
        num_bits: usize,
        gen_contents: impl FnOnce() -> Result<T, SerializeError>,
    ) -> Result<(), SerializeError> {
        let val: u64 = gen_contents()?.into();

        // A shift by 64 would overflow, and any u64 fits in >= 64 bits anyway.
        if num_bits < 64 && val >= 1u64 << num_bits {
            return Err(SerializeError::ValueTooLarge);
        }

        let mut remaining_val = val;
        let mut remaining_bits = num_bits;
        while remaining_bits > 0 {
            let free_bits_in_curr_byte = usize::from(8 - self.curr_bit_offset);
            if remaining_bits < free_bits_in_curr_byte {
                // We cannot finish the current byte: stash the bits and stop.
                self.curr_byte |= (remaining_val as u8) << self.curr_bit_offset;
                self.curr_bit_offset += remaining_bits as u8;
                break;
            }

            // Fill the rest of the current byte and emit it.
            let val_for_this_byte = (remaining_val & ((1u64 << free_bits_in_curr_byte) - 1)) as u8;
            self.buf.push(self.curr_byte | (val_for_this_byte << self.curr_bit_offset));

            // Start a fresh pending byte.
            self.curr_bit_offset = 0;
            self.curr_byte = 0;

            remaining_val >>= free_bits_in_curr_byte;
            remaining_bits -= free_bits_in_curr_byte;
        }

        Ok(())
    }
}