1 //! Utilities to work with raw WebSocket frames.
2 
3 pub mod coding;
4 
5 #[allow(clippy::module_inception)]
6 mod frame;
7 mod mask;
8 
9 use crate::{
10     error::{CapacityError, Error, Result},
11     Message, ReadBuffer,
12 };
13 use log::*;
14 use std::io::{Error as IoError, ErrorKind as IoErrorKind, Read, Write};
15 
16 pub use self::frame::{CloseFrame, Frame, FrameHeader};
17 
/// A reader and writer for WebSocket frames.
///
/// Pairs an underlying stream with a `FrameCodec`, which owns the read and
/// write buffers used while frames are read from or written to that stream.
#[derive(Debug)]
pub struct FrameSocket<Stream> {
    /// The underlying network stream.
    stream: Stream,
    /// Codec for reading/writing frames; holds the buffered input and output bytes.
    codec: FrameCodec,
}
26 
27 impl<Stream> FrameSocket<Stream> {
28     /// Create a new frame socket.
new(stream: Stream) -> Self29     pub fn new(stream: Stream) -> Self {
30         FrameSocket { stream, codec: FrameCodec::new() }
31     }
32 
33     /// Create a new frame socket from partially read data.
from_partially_read(stream: Stream, part: Vec<u8>) -> Self34     pub fn from_partially_read(stream: Stream, part: Vec<u8>) -> Self {
35         FrameSocket { stream, codec: FrameCodec::from_partially_read(part) }
36     }
37 
38     /// Extract a stream from the socket.
into_inner(self) -> (Stream, Vec<u8>)39     pub fn into_inner(self) -> (Stream, Vec<u8>) {
40         (self.stream, self.codec.in_buffer.into_vec())
41     }
42 
43     /// Returns a shared reference to the inner stream.
get_ref(&self) -> &Stream44     pub fn get_ref(&self) -> &Stream {
45         &self.stream
46     }
47 
48     /// Returns a mutable reference to the inner stream.
get_mut(&mut self) -> &mut Stream49     pub fn get_mut(&mut self) -> &mut Stream {
50         &mut self.stream
51     }
52 }
53 
54 impl<Stream> FrameSocket<Stream>
55 where
56     Stream: Read,
57 {
58     /// Read a frame from stream.
read(&mut self, max_size: Option<usize>) -> Result<Option<Frame>>59     pub fn read(&mut self, max_size: Option<usize>) -> Result<Option<Frame>> {
60         self.codec.read_frame(&mut self.stream, max_size)
61     }
62 }
63 
64 impl<Stream> FrameSocket<Stream>
65 where
66     Stream: Write,
67 {
68     /// Writes and immediately flushes a frame.
69     /// Equivalent to calling [`write`](Self::write) then [`flush`](Self::flush).
send(&mut self, frame: Frame) -> Result<()>70     pub fn send(&mut self, frame: Frame) -> Result<()> {
71         self.write(frame)?;
72         self.flush()
73     }
74 
75     /// Write a frame to stream.
76     ///
77     /// A subsequent call should be made to [`flush`](Self::flush) to flush writes.
78     ///
79     /// This function guarantees that the frame is queued unless [`Error::WriteBufferFull`]
80     /// is returned.
81     /// In order to handle WouldBlock or Incomplete, call [`flush`](Self::flush) afterwards.
write(&mut self, frame: Frame) -> Result<()>82     pub fn write(&mut self, frame: Frame) -> Result<()> {
83         self.codec.buffer_frame(&mut self.stream, frame)
84     }
85 
86     /// Flush writes.
flush(&mut self) -> Result<()>87     pub fn flush(&mut self) -> Result<()> {
88         self.codec.write_out_buffer(&mut self.stream)?;
89         Ok(self.stream.flush()?)
90     }
91 }
92 
/// A codec for WebSocket frames.
///
/// Holds the input buffer and the pending header used by `read_frame`, and
/// the output buffer plus its limits used by `buffer_frame` and
/// `write_out_buffer`.
#[derive(Debug)]
pub(super) struct FrameCodec {
    /// Buffer to read data from the stream.
    in_buffer: ReadBuffer,
    /// Buffer to send packets to the network.
    out_buffer: Vec<u8>,
    /// Capacity limit for `out_buffer`.
    ///
    /// `buffer_frame` rejects a frame whose length, added to the bytes already
    /// buffered, would exceed this value.
    max_out_buffer_len: usize,
    /// Buffer target length to reach before writing to the stream
    /// on calls to `buffer_frame`.
    ///
    /// Setting this to non-zero will buffer small writes from hitting
    /// the stream.
    out_buffer_write_len: usize,
    /// Header and remaining size of the incoming packet being processed.
    ///
    /// Populated by `read_frame` once a header has been parsed and taken out
    /// again when the matching payload has been read in full.
    header: Option<(FrameHeader, u64)>,
}
111 
112 impl FrameCodec {
113     /// Create a new frame codec.
new() -> Self114     pub(super) fn new() -> Self {
115         Self {
116             in_buffer: ReadBuffer::new(),
117             out_buffer: Vec::new(),
118             max_out_buffer_len: usize::MAX,
119             out_buffer_write_len: 0,
120             header: None,
121         }
122     }
123 
124     /// Create a new frame codec from partially read data.
from_partially_read(part: Vec<u8>) -> Self125     pub(super) fn from_partially_read(part: Vec<u8>) -> Self {
126         Self {
127             in_buffer: ReadBuffer::from_partially_read(part),
128             out_buffer: Vec::new(),
129             max_out_buffer_len: usize::MAX,
130             out_buffer_write_len: 0,
131             header: None,
132         }
133     }
134 
    /// Sets a maximum size for the out buffer.
    ///
    /// `buffer_frame` returns `Error::WriteBufferFull` instead of queuing a
    /// frame when the frame's length plus the bytes already buffered would
    /// exceed `max`.
    pub(super) fn set_max_out_buffer_len(&mut self, max: usize) {
        self.max_out_buffer_len = max;
    }
139 
    /// Sets [`Self::buffer_frame`] buffer target length to reach before
    /// writing to the stream.
    ///
    /// `buffer_frame` only writes the out buffer to the stream once the
    /// buffer is strictly longer than `len`.
    pub(super) fn set_out_buffer_write_len(&mut self, len: usize) {
        self.out_buffer_write_len = len;
    }
145 
146     /// Read a frame from the provided stream.
read_frame<Stream>( &mut self, stream: &mut Stream, max_size: Option<usize>, ) -> Result<Option<Frame>> where Stream: Read,147     pub(super) fn read_frame<Stream>(
148         &mut self,
149         stream: &mut Stream,
150         max_size: Option<usize>,
151     ) -> Result<Option<Frame>>
152     where
153         Stream: Read,
154     {
155         let max_size = max_size.unwrap_or_else(usize::max_value);
156 
157         let payload = loop {
158             {
159                 let cursor = self.in_buffer.as_cursor_mut();
160 
161                 if self.header.is_none() {
162                     self.header = FrameHeader::parse(cursor)?;
163                 }
164 
165                 if let Some((_, ref length)) = self.header {
166                     let length = *length;
167 
168                     // Enforce frame size limit early and make sure `length`
169                     // is not too big (fits into `usize`).
170                     if length > max_size as u64 {
171                         return Err(Error::Capacity(CapacityError::MessageTooLong {
172                             size: length as usize,
173                             max_size,
174                         }));
175                     }
176 
177                     let input_size = cursor.get_ref().len() as u64 - cursor.position();
178                     if length <= input_size {
179                         // No truncation here since `length` is checked above
180                         let mut payload = Vec::with_capacity(length as usize);
181                         if length > 0 {
182                             cursor.take(length).read_to_end(&mut payload)?;
183                         }
184                         break payload;
185                     }
186                 }
187             }
188 
189             // Not enough data in buffer.
190             let size = self.in_buffer.read_from(stream)?;
191             if size == 0 {
192                 trace!("no frame received");
193                 return Ok(None);
194             }
195         };
196 
197         let (header, length) = self.header.take().expect("Bug: no frame header");
198         debug_assert_eq!(payload.len() as u64, length);
199         let frame = Frame::from_payload(header, payload);
200         trace!("received frame {}", frame);
201         Ok(Some(frame))
202     }
203 
204     /// Writes a frame into the `out_buffer`.
205     /// If the out buffer size is over the `out_buffer_write_len` will also write
206     /// the out buffer into the provided `stream`.
207     ///
208     /// To ensure buffered frames are written call [`Self::write_out_buffer`].
209     ///
210     /// May write to the stream, will **not** flush.
buffer_frame<Stream>(&mut self, stream: &mut Stream, frame: Frame) -> Result<()> where Stream: Write,211     pub(super) fn buffer_frame<Stream>(&mut self, stream: &mut Stream, frame: Frame) -> Result<()>
212     where
213         Stream: Write,
214     {
215         if frame.len() + self.out_buffer.len() > self.max_out_buffer_len {
216             return Err(Error::WriteBufferFull(Message::Frame(frame)));
217         }
218 
219         trace!("writing frame {}", frame);
220 
221         self.out_buffer.reserve(frame.len());
222         frame.format(&mut self.out_buffer).expect("Bug: can't write to vector");
223 
224         if self.out_buffer.len() > self.out_buffer_write_len {
225             self.write_out_buffer(stream)
226         } else {
227             Ok(())
228         }
229     }
230 
231     /// Writes the out_buffer to the provided stream.
232     ///
233     /// Does **not** flush.
write_out_buffer<Stream>(&mut self, stream: &mut Stream) -> Result<()> where Stream: Write,234     pub(super) fn write_out_buffer<Stream>(&mut self, stream: &mut Stream) -> Result<()>
235     where
236         Stream: Write,
237     {
238         while !self.out_buffer.is_empty() {
239             let len = stream.write(&self.out_buffer)?;
240             if len == 0 {
241                 // This is the same as "Connection reset by peer"
242                 return Err(IoError::new(
243                     IoErrorKind::ConnectionReset,
244                     "Connection reset while sending",
245                 )
246                 .into());
247             }
248             self.out_buffer.drain(0..len);
249         }
250 
251         Ok(())
252     }
253 }
254 
#[cfg(test)]
mod tests {
    use super::{Frame, FrameSocket};
    use crate::error::{CapacityError, Error};
    use std::io::Cursor;

    /// Two complete frames followed by one stray byte: both payloads come
    /// back in order, then `None`, and the stray byte is handed back by
    /// `into_inner`.
    #[test]
    fn read_frames() {
        // Per RFC 6455 framing, 0x82 is FIN + binary opcode and the second
        // header byte is the (unmasked) payload length.
        let input = vec![
            0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x82, 0x03, 0x03, 0x02, 0x01,
            0x99,
        ];
        let mut socket = FrameSocket::new(Cursor::new(input));

        let first = socket.read(None).unwrap().unwrap();
        assert_eq!(first.into_data(), vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);

        let second = socket.read(None).unwrap().unwrap();
        assert_eq!(second.into_data(), vec![0x03, 0x02, 0x01]);

        let third = socket.read(None).unwrap();
        assert!(third.is_none());

        let (_, leftover) = socket.into_inner();
        assert_eq!(leftover, vec![0x99]);
    }

    /// A frame split between already-read bytes and the stream is reassembled.
    #[test]
    fn from_partially_read() {
        let already_read = vec![0x82, 0x07, 0x01];
        let remainder = Cursor::new(vec![0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
        let mut socket = FrameSocket::from_partially_read(remainder, already_read);

        let frame = socket.read(None).unwrap().unwrap();
        assert_eq!(frame.into_data(), vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
    }

    /// A ping and a pong sent back to back end up serialized in the stream.
    #[test]
    fn write_frames() {
        let mut socket = FrameSocket::new(Vec::new());

        let ping = Frame::ping(vec![0x04, 0x05]);
        socket.send(ping).unwrap();

        let pong = Frame::pong(vec![0x01]);
        socket.send(pong).unwrap();

        let (written, _) = socket.into_inner();
        assert_eq!(written, vec![0x89, 0x02, 0x04, 0x05, 0x8a, 0x01, 0x01]);
    }

    /// A header announcing a huge payload length must not panic the reader.
    #[test]
    fn parse_overflow() {
        let input = vec![
            0x83, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00,
        ];
        let mut socket = FrameSocket::new(Cursor::new(input));
        let _ = socket.read(None); // should not crash
    }

    /// A 7-byte payload read with a limit of 5 is rejected as too long.
    #[test]
    fn size_limit_hit() {
        let input = vec![0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07];
        let mut socket = FrameSocket::new(Cursor::new(input));

        let outcome = socket.read(Some(5));
        assert!(matches!(
            outcome,
            Err(Error::Capacity(CapacityError::MessageTooLong { size: 7, max_size: 5 }))
        ));
    }
}
326