1 use error::Result;
2 use serde;
3 use std::io;
4 
5 /// An optional Read trait for advanced Bincode usage.
6 ///
7 /// It is highly recommended to use bincode with `io::Read` or `&[u8]` before
8 /// implementing a custom `BincodeRead`.
9 ///
10 /// The forward_read_* methods are necessary because some byte sources want
11 /// to pass a long-lived borrow to the visitor and others want to pass a
12 /// transient slice.
13 pub trait BincodeRead<'storage>: io::Read {
14     /// Check that the next `length` bytes are a valid string and pass
15     /// it on to the serde reader.
forward_read_str<V>(&mut self, length: usize, visitor: V) -> Result<V::Value> where V: serde::de::Visitor<'storage>16     fn forward_read_str<V>(&mut self, length: usize, visitor: V) -> Result<V::Value>
17     where
18         V: serde::de::Visitor<'storage>;
19 
20     /// Transfer ownership of the next `length` bytes to the caller.
get_byte_buffer(&mut self, length: usize) -> Result<Vec<u8>>21     fn get_byte_buffer(&mut self, length: usize) -> Result<Vec<u8>>;
22 
23     /// Pass a slice of the next `length` bytes on to the serde reader.
forward_read_bytes<V>(&mut self, length: usize, visitor: V) -> Result<V::Value> where V: serde::de::Visitor<'storage>24     fn forward_read_bytes<V>(&mut self, length: usize, visitor: V) -> Result<V::Value>
25     where
26         V: serde::de::Visitor<'storage>;
27 }
28 
/// A `BincodeRead` implementation for byte slices.
#[derive(Debug)]
pub struct SliceReader<'storage> {
    /// The portion of the input that has not been consumed yet.
    slice: &'storage [u8],
}
33 
/// A `BincodeRead` implementation for `io::Read`ers.
#[derive(Debug)]
pub struct IoReader<R> {
    /// The wrapped byte source.
    reader: R,
    /// Scratch space that variable-length reads are copied into before being
    /// handed on; it is kept between reads so its allocation can be reused.
    temp_buffer: Vec<u8>,
}
39 
40 impl<'storage> SliceReader<'storage> {
41     /// Constructs a slice reader
new(bytes: &'storage [u8]) -> SliceReader<'storage>42     pub(crate) fn new(bytes: &'storage [u8]) -> SliceReader<'storage> {
43         SliceReader { slice: bytes }
44     }
45 
46     #[inline(always)]
get_byte_slice(&mut self, length: usize) -> Result<&'storage [u8]>47     fn get_byte_slice(&mut self, length: usize) -> Result<&'storage [u8]> {
48         if length > self.slice.len() {
49             return Err(SliceReader::unexpected_eof());
50         }
51         let (read_slice, remaining) = self.slice.split_at(length);
52         self.slice = remaining;
53         Ok(read_slice)
54     }
55 
is_finished(&self) -> bool56     pub(crate) fn is_finished(&self) -> bool {
57         self.slice.is_empty()
58     }
59 }
60 
61 impl<R> IoReader<R> {
62     /// Constructs an IoReadReader
new(r: R) -> IoReader<R>63     pub(crate) fn new(r: R) -> IoReader<R> {
64         IoReader {
65             reader: r,
66             temp_buffer: vec![],
67         }
68     }
69 }
70 
71 impl<'storage> io::Read for SliceReader<'storage> {
72     #[inline(always)]
read(&mut self, out: &mut [u8]) -> io::Result<usize>73     fn read(&mut self, out: &mut [u8]) -> io::Result<usize> {
74         if out.len() > self.slice.len() {
75             return Err(io::ErrorKind::UnexpectedEof.into());
76         }
77         let (read_slice, remaining) = self.slice.split_at(out.len());
78         out.copy_from_slice(read_slice);
79         self.slice = remaining;
80 
81         Ok(out.len())
82     }
83 
84     #[inline(always)]
read_exact(&mut self, out: &mut [u8]) -> io::Result<()>85     fn read_exact(&mut self, out: &mut [u8]) -> io::Result<()> {
86         self.read(out).map(|_| ())
87     }
88 }
89 
90 impl<R: io::Read> io::Read for IoReader<R> {
91     #[inline(always)]
read(&mut self, out: &mut [u8]) -> io::Result<usize>92     fn read(&mut self, out: &mut [u8]) -> io::Result<usize> {
93         self.reader.read(out)
94     }
95     #[inline(always)]
read_exact(&mut self, out: &mut [u8]) -> io::Result<()>96     fn read_exact(&mut self, out: &mut [u8]) -> io::Result<()> {
97         self.reader.read_exact(out)
98     }
99 }
100 
101 impl<'storage> SliceReader<'storage> {
102     #[inline(always)]
unexpected_eof() -> Box<::ErrorKind>103     fn unexpected_eof() -> Box<::ErrorKind> {
104         Box::new(::ErrorKind::Io(io::Error::new(
105             io::ErrorKind::UnexpectedEof,
106             "",
107         )))
108     }
109 }
110 
111 impl<'storage> BincodeRead<'storage> for SliceReader<'storage> {
112     #[inline(always)]
forward_read_str<V>(&mut self, length: usize, visitor: V) -> Result<V::Value> where V: serde::de::Visitor<'storage>,113     fn forward_read_str<V>(&mut self, length: usize, visitor: V) -> Result<V::Value>
114     where
115         V: serde::de::Visitor<'storage>,
116     {
117         use ErrorKind;
118         let string = match ::std::str::from_utf8(self.get_byte_slice(length)?) {
119             Ok(s) => s,
120             Err(e) => return Err(ErrorKind::InvalidUtf8Encoding(e).into()),
121         };
122         visitor.visit_borrowed_str(string)
123     }
124 
125     #[inline(always)]
get_byte_buffer(&mut self, length: usize) -> Result<Vec<u8>>126     fn get_byte_buffer(&mut self, length: usize) -> Result<Vec<u8>> {
127         self.get_byte_slice(length).map(|x| x.to_vec())
128     }
129 
130     #[inline(always)]
forward_read_bytes<V>(&mut self, length: usize, visitor: V) -> Result<V::Value> where V: serde::de::Visitor<'storage>,131     fn forward_read_bytes<V>(&mut self, length: usize, visitor: V) -> Result<V::Value>
132     where
133         V: serde::de::Visitor<'storage>,
134     {
135         visitor.visit_borrowed_bytes(self.get_byte_slice(length)?)
136     }
137 }
138 
139 impl<R> IoReader<R>
140 where
141     R: io::Read,
142 {
fill_buffer(&mut self, length: usize) -> Result<()>143     fn fill_buffer(&mut self, length: usize) -> Result<()> {
144         self.temp_buffer.resize(length, 0);
145 
146         self.reader.read_exact(&mut self.temp_buffer)?;
147 
148         Ok(())
149     }
150 }
151 
152 impl<'a, R> BincodeRead<'a> for IoReader<R>
153 where
154     R: io::Read,
155 {
forward_read_str<V>(&mut self, length: usize, visitor: V) -> Result<V::Value> where V: serde::de::Visitor<'a>,156     fn forward_read_str<V>(&mut self, length: usize, visitor: V) -> Result<V::Value>
157     where
158         V: serde::de::Visitor<'a>,
159     {
160         self.fill_buffer(length)?;
161 
162         let string = match ::std::str::from_utf8(&self.temp_buffer[..]) {
163             Ok(s) => s,
164             Err(e) => return Err(::ErrorKind::InvalidUtf8Encoding(e).into()),
165         };
166 
167         visitor.visit_str(string)
168     }
169 
get_byte_buffer(&mut self, length: usize) -> Result<Vec<u8>>170     fn get_byte_buffer(&mut self, length: usize) -> Result<Vec<u8>> {
171         self.fill_buffer(length)?;
172         Ok(::std::mem::replace(&mut self.temp_buffer, Vec::new()))
173     }
174 
forward_read_bytes<V>(&mut self, length: usize, visitor: V) -> Result<V::Value> where V: serde::de::Visitor<'a>,175     fn forward_read_bytes<V>(&mut self, length: usize, visitor: V) -> Result<V::Value>
176     where
177         V: serde::de::Visitor<'a>,
178     {
179         self.fill_buffer(length)?;
180         visitor.visit_bytes(&self.temp_buffer[..])
181     }
182 }
183 
#[cfg(test)]
mod test {
    use super::IoReader;

    /// `fill_buffer` must size the scratch buffer to exactly the requested
    /// length on every call, whether that grows or shrinks it, and fill it
    /// with the next bytes of the stream in order.
    #[test]
    fn test_fill_buffer() {
        // Distinct byte values, so the assertions can tell which part of the
        // stream ended up in the buffer (an all-zero input cannot).
        let buffer: Vec<u8> = (0u8..64).collect();
        let mut reader = IoReader::new(buffer.as_slice());

        reader.fill_buffer(20).unwrap();
        assert_eq!(20, reader.temp_buffer.len());
        assert_eq!(&buffer[0..20], &reader.temp_buffer[..]);

        reader.fill_buffer(30).unwrap();
        assert_eq!(30, reader.temp_buffer.len());
        assert_eq!(&buffer[20..50], &reader.temp_buffer[..]);

        reader.fill_buffer(5).unwrap();
        assert_eq!(5, reader.temp_buffer.len());
        assert_eq!(&buffer[50..55], &reader.temp_buffer[..]);
    }

    /// Asking for more bytes than the stream holds must be reported as an
    /// error rather than silently returning a short buffer.
    #[test]
    fn test_fill_buffer_eof() {
        let buffer = [0u8; 4];
        let mut reader = IoReader::new(&buffer[..]);
        assert!(reader.fill_buffer(5).is_err());
    }
}
203