1 //! Slice reader. 2 3 use crate::{BytesRef, Decode, Error, ErrorKind, Header, Length, Reader, Result, Tag}; 4 5 /// [`Reader`] which consumes an input byte slice. 6 #[derive(Clone, Debug)] 7 pub struct SliceReader<'a> { 8 /// Byte slice being decoded. 9 bytes: BytesRef<'a>, 10 11 /// Did the decoding operation fail? 12 failed: bool, 13 14 /// Position within the decoded slice. 15 position: Length, 16 } 17 18 impl<'a> SliceReader<'a> { 19 /// Create a new slice reader for the given byte slice. new(bytes: &'a [u8]) -> Result<Self>20 pub fn new(bytes: &'a [u8]) -> Result<Self> { 21 Ok(Self { 22 bytes: BytesRef::new(bytes)?, 23 failed: false, 24 position: Length::ZERO, 25 }) 26 } 27 28 /// Return an error with the given [`ErrorKind`], annotating it with 29 /// context about where the error occurred. error(&mut self, kind: ErrorKind) -> Error30 pub fn error(&mut self, kind: ErrorKind) -> Error { 31 self.failed = true; 32 kind.at(self.position) 33 } 34 35 /// Return an error for an invalid value with the given tag. value_error(&mut self, tag: Tag) -> Error36 pub fn value_error(&mut self, tag: Tag) -> Error { 37 self.error(tag.value_error().kind()) 38 } 39 40 /// Did the decoding operation fail due to an error? is_failed(&self) -> bool41 pub fn is_failed(&self) -> bool { 42 self.failed 43 } 44 45 /// Obtain the remaining bytes in this slice reader from the current cursor 46 /// position. remaining(&self) -> Result<&'a [u8]>47 fn remaining(&self) -> Result<&'a [u8]> { 48 if self.is_failed() { 49 Err(ErrorKind::Failed.at(self.position)) 50 } else { 51 self.bytes 52 .as_slice() 53 .get(self.position.try_into()?..) 54 .ok_or_else(|| Error::incomplete(self.input_len())) 55 } 56 } 57 } 58 59 impl<'a> Reader<'a> for SliceReader<'a> { input_len(&self) -> Length60 fn input_len(&self) -> Length { 61 self.bytes.len() 62 } 63 peek_byte(&self) -> Option<u8>64 fn peek_byte(&self) -> Option<u8> { 65 self.remaining() 66 .ok() 67 .and_then(|bytes| bytes.first().cloned()) 68 } 69 peek_header(&self) -> Result<Header>70 fn peek_header(&self) -> Result<Header> { 71 Header::decode(&mut self.clone()) 72 } 73 position(&self) -> Length74 fn position(&self) -> Length { 75 self.position 76 } 77 read_slice(&mut self, len: Length) -> Result<&'a [u8]>78 fn read_slice(&mut self, len: Length) -> Result<&'a [u8]> { 79 if self.is_failed() { 80 return Err(self.error(ErrorKind::Failed)); 81 } 82 83 match self.remaining()?.get(..len.try_into()?) { 84 Some(result) => { 85 self.position = (self.position + len)?; 86 Ok(result) 87 } 88 None => Err(self.error(ErrorKind::Incomplete { 89 expected_len: (self.position + len)?, 90 actual_len: self.input_len(), 91 })), 92 } 93 } 94 decode<T: Decode<'a>>(&mut self) -> Result<T>95 fn decode<T: Decode<'a>>(&mut self) -> Result<T> { 96 if self.is_failed() { 97 return Err(self.error(ErrorKind::Failed)); 98 } 99 100 T::decode(self).map_err(|e| { 101 self.failed = true; 102 e.nested(self.position) 103 }) 104 } 105 error(&mut self, kind: ErrorKind) -> Error106 fn error(&mut self, kind: ErrorKind) -> Error { 107 self.failed = true; 108 kind.at(self.position) 109 } 110 finish<T>(self, value: T) -> Result<T>111 fn finish<T>(self, value: T) -> Result<T> { 112 if self.is_failed() { 113 Err(ErrorKind::Failed.at(self.position)) 114 } else if !self.is_finished() { 115 Err(ErrorKind::TrailingData { 116 decoded: self.position, 117 remaining: self.remaining_len(), 118 } 119 .at(self.position)) 120 } else { 121 Ok(value) 122 } 123 } 124 remaining_len(&self) -> Length125 fn remaining_len(&self) -> Length { 126 debug_assert!(self.position <= self.input_len()); 127 self.input_len().saturating_sub(self.position) 128 } 129 } 130 131 #[cfg(test)] 132 mod tests { 133 use super::SliceReader; 134 use crate::{Decode, ErrorKind, Length, Reader, Tag}; 135 use hex_literal::hex; 136 137 // INTEGER: 42 138 const EXAMPLE_MSG: &[u8] = &hex!("02012A00"); 139 140 #[test] empty_message()141 fn empty_message() { 142 let mut reader = SliceReader::new(&[]).unwrap(); 143 let err = bool::decode(&mut reader).err().unwrap(); 144 assert_eq!(Some(Length::ZERO), err.position()); 145 146 match err.kind() { 147 ErrorKind::Incomplete { 148 expected_len, 149 actual_len, 150 } => { 151 assert_eq!(actual_len, 0u8.into()); 152 assert_eq!(expected_len, 1u8.into()); 153 } 154 other => panic!("unexpected error kind: {:?}", other), 155 } 156 } 157 158 #[test] invalid_field_length()159 fn invalid_field_length() { 160 const MSG_LEN: usize = 2; 161 162 let mut reader = SliceReader::new(&EXAMPLE_MSG[..MSG_LEN]).unwrap(); 163 let err = i8::decode(&mut reader).err().unwrap(); 164 assert_eq!(Some(Length::from(2u8)), err.position()); 165 166 match err.kind() { 167 ErrorKind::Incomplete { 168 expected_len, 169 actual_len, 170 } => { 171 assert_eq!(actual_len, MSG_LEN.try_into().unwrap()); 172 assert_eq!(expected_len, (MSG_LEN + 1).try_into().unwrap()); 173 } 174 other => panic!("unexpected error kind: {:?}", other), 175 } 176 } 177 178 #[test] trailing_data()179 fn trailing_data() { 180 let mut reader = SliceReader::new(EXAMPLE_MSG).unwrap(); 181 let x = i8::decode(&mut reader).unwrap(); 182 assert_eq!(42i8, x); 183 184 let err = reader.finish(x).err().unwrap(); 185 assert_eq!(Some(Length::from(3u8)), err.position()); 186 187 assert_eq!( 188 ErrorKind::TrailingData { 189 decoded: 3u8.into(), 190 remaining: 1u8.into() 191 }, 192 err.kind() 193 ); 194 } 195 196 #[test] peek_tag()197 fn peek_tag() { 198 let reader = SliceReader::new(EXAMPLE_MSG).unwrap(); 199 assert_eq!(reader.position(), Length::ZERO); 200 assert_eq!(reader.peek_tag().unwrap(), Tag::Integer); 201 assert_eq!(reader.position(), Length::ZERO); // Position unchanged 202 } 203 204 #[test] peek_header()205 fn peek_header() { 206 let reader = SliceReader::new(EXAMPLE_MSG).unwrap(); 207 assert_eq!(reader.position(), Length::ZERO); 208 209 let header = reader.peek_header().unwrap(); 210 assert_eq!(header.tag, Tag::Integer); 211 assert_eq!(header.length, Length::ONE); 212 assert_eq!(reader.position(), Length::ZERO); // Position unchanged 213 } 214 } 215