1 //! Slice reader.
2 
3 use crate::{BytesRef, Decode, Error, ErrorKind, Header, Length, Reader, Result, Tag};
4 
5 /// [`Reader`] which consumes an input byte slice.
6 #[derive(Clone, Debug)]
7 pub struct SliceReader<'a> {
8     /// Byte slice being decoded.
9     bytes: BytesRef<'a>,
10 
11     /// Did the decoding operation fail?
12     failed: bool,
13 
14     /// Position within the decoded slice.
15     position: Length,
16 }
17 
18 impl<'a> SliceReader<'a> {
19     /// Create a new slice reader for the given byte slice.
new(bytes: &'a [u8]) -> Result<Self>20     pub fn new(bytes: &'a [u8]) -> Result<Self> {
21         Ok(Self {
22             bytes: BytesRef::new(bytes)?,
23             failed: false,
24             position: Length::ZERO,
25         })
26     }
27 
28     /// Return an error with the given [`ErrorKind`], annotating it with
29     /// context about where the error occurred.
error(&mut self, kind: ErrorKind) -> Error30     pub fn error(&mut self, kind: ErrorKind) -> Error {
31         self.failed = true;
32         kind.at(self.position)
33     }
34 
35     /// Return an error for an invalid value with the given tag.
value_error(&mut self, tag: Tag) -> Error36     pub fn value_error(&mut self, tag: Tag) -> Error {
37         self.error(tag.value_error().kind())
38     }
39 
40     /// Did the decoding operation fail due to an error?
is_failed(&self) -> bool41     pub fn is_failed(&self) -> bool {
42         self.failed
43     }
44 
45     /// Obtain the remaining bytes in this slice reader from the current cursor
46     /// position.
remaining(&self) -> Result<&'a [u8]>47     fn remaining(&self) -> Result<&'a [u8]> {
48         if self.is_failed() {
49             Err(ErrorKind::Failed.at(self.position))
50         } else {
51             self.bytes
52                 .as_slice()
53                 .get(self.position.try_into()?..)
54                 .ok_or_else(|| Error::incomplete(self.input_len()))
55         }
56     }
57 }
58 
59 impl<'a> Reader<'a> for SliceReader<'a> {
input_len(&self) -> Length60     fn input_len(&self) -> Length {
61         self.bytes.len()
62     }
63 
peek_byte(&self) -> Option<u8>64     fn peek_byte(&self) -> Option<u8> {
65         self.remaining()
66             .ok()
67             .and_then(|bytes| bytes.first().cloned())
68     }
69 
peek_header(&self) -> Result<Header>70     fn peek_header(&self) -> Result<Header> {
71         Header::decode(&mut self.clone())
72     }
73 
position(&self) -> Length74     fn position(&self) -> Length {
75         self.position
76     }
77 
read_slice(&mut self, len: Length) -> Result<&'a [u8]>78     fn read_slice(&mut self, len: Length) -> Result<&'a [u8]> {
79         if self.is_failed() {
80             return Err(self.error(ErrorKind::Failed));
81         }
82 
83         match self.remaining()?.get(..len.try_into()?) {
84             Some(result) => {
85                 self.position = (self.position + len)?;
86                 Ok(result)
87             }
88             None => Err(self.error(ErrorKind::Incomplete {
89                 expected_len: (self.position + len)?,
90                 actual_len: self.input_len(),
91             })),
92         }
93     }
94 
decode<T: Decode<'a>>(&mut self) -> Result<T>95     fn decode<T: Decode<'a>>(&mut self) -> Result<T> {
96         if self.is_failed() {
97             return Err(self.error(ErrorKind::Failed));
98         }
99 
100         T::decode(self).map_err(|e| {
101             self.failed = true;
102             e.nested(self.position)
103         })
104     }
105 
error(&mut self, kind: ErrorKind) -> Error106     fn error(&mut self, kind: ErrorKind) -> Error {
107         self.failed = true;
108         kind.at(self.position)
109     }
110 
finish<T>(self, value: T) -> Result<T>111     fn finish<T>(self, value: T) -> Result<T> {
112         if self.is_failed() {
113             Err(ErrorKind::Failed.at(self.position))
114         } else if !self.is_finished() {
115             Err(ErrorKind::TrailingData {
116                 decoded: self.position,
117                 remaining: self.remaining_len(),
118             }
119             .at(self.position))
120         } else {
121             Ok(value)
122         }
123     }
124 
remaining_len(&self) -> Length125     fn remaining_len(&self) -> Length {
126         debug_assert!(self.position <= self.input_len());
127         self.input_len().saturating_sub(self.position)
128     }
129 }
130 
#[cfg(test)]
mod tests {
    use super::SliceReader;
    use crate::{Decode, ErrorKind, Length, Reader, Tag};
    use hex_literal::hex;

    // DER `INTEGER 42` (`02 01 2A`) followed by one trailing byte (`00`).
    const EXAMPLE_MSG: &[u8] = &hex!("02012A00");

    #[test]
    fn empty_message() {
        let mut reader = SliceReader::new(&[]).unwrap();
        let err = bool::decode(&mut reader).expect_err("empty input must not decode");

        assert_eq!(err.position(), Some(Length::ZERO));
        assert_eq!(
            err.kind(),
            ErrorKind::Incomplete {
                expected_len: Length::ONE,
                actual_len: Length::ZERO,
            }
        );
    }

    #[test]
    fn invalid_field_length() {
        const MSG_LEN: usize = 2;

        // Header only: the single value byte of the INTEGER is missing.
        let mut reader = SliceReader::new(&EXAMPLE_MSG[..MSG_LEN]).unwrap();
        let err = i8::decode(&mut reader).expect_err("truncated INTEGER must not decode");

        assert_eq!(err.position(), Some(Length::from(2u8)));
        assert_eq!(
            err.kind(),
            ErrorKind::Incomplete {
                expected_len: Length::try_from(MSG_LEN + 1).unwrap(),
                actual_len: Length::try_from(MSG_LEN).unwrap(),
            }
        );
    }

    #[test]
    fn trailing_data() {
        let mut reader = SliceReader::new(EXAMPLE_MSG).unwrap();
        let value = i8::decode(&mut reader).unwrap();
        assert_eq!(value, 42i8);

        let err = reader
            .finish(value)
            .expect_err("the trailing byte must be reported");

        assert_eq!(err.position(), Some(Length::from(3u8)));
        assert_eq!(
            err.kind(),
            ErrorKind::TrailingData {
                decoded: 3u8.into(),
                remaining: Length::ONE,
            }
        );
    }

    #[test]
    fn peek_tag() {
        let reader = SliceReader::new(EXAMPLE_MSG).unwrap();
        let start = reader.position();
        assert_eq!(start, Length::ZERO);

        assert_eq!(reader.peek_tag().unwrap(), Tag::Integer);

        // Peeking must leave the cursor where it was.
        assert_eq!(reader.position(), start);
    }

    #[test]
    fn peek_header() {
        let reader = SliceReader::new(EXAMPLE_MSG).unwrap();
        let start = reader.position();
        assert_eq!(start, Length::ZERO);

        let header = reader.peek_header().unwrap();
        assert_eq!(header.tag, Tag::Integer);
        assert_eq!(header.length, Length::ONE);

        // Peeking must leave the cursor where it was.
        assert_eq!(reader.position(), start);
    }
}
215