1 use std::{
2     convert::{AsRef, From, Into, TryFrom},
3     fmt,
4     result::Result as StdResult,
5     str,
6 };
7 
8 use super::frame::{CloseFrame, Frame};
9 use crate::error::{CapacityError, Error, Result};
10 
11 mod string_collect {
12     use utf8::DecodeError;
13 
14     use crate::error::{Error, Result};
15 
16     #[derive(Debug)]
17     pub struct StringCollector {
18         data: String,
19         incomplete: Option<utf8::Incomplete>,
20     }
21 
22     impl StringCollector {
new() -> Self23         pub fn new() -> Self {
24             StringCollector { data: String::new(), incomplete: None }
25         }
26 
len(&self) -> usize27         pub fn len(&self) -> usize {
28             self.data
29                 .len()
30                 .saturating_add(self.incomplete.map(|i| i.buffer_len as usize).unwrap_or(0))
31         }
32 
extend<T: AsRef<[u8]>>(&mut self, tail: T) -> Result<()>33         pub fn extend<T: AsRef<[u8]>>(&mut self, tail: T) -> Result<()> {
34             let mut input: &[u8] = tail.as_ref();
35 
36             if let Some(mut incomplete) = self.incomplete.take() {
37                 if let Some((result, rest)) = incomplete.try_complete(input) {
38                     input = rest;
39                     if let Ok(text) = result {
40                         self.data.push_str(text);
41                     } else {
42                         return Err(Error::Utf8);
43                     }
44                 } else {
45                     input = &[];
46                     self.incomplete = Some(incomplete);
47                 }
48             }
49 
50             if !input.is_empty() {
51                 match utf8::decode(input) {
52                     Ok(text) => {
53                         self.data.push_str(text);
54                         Ok(())
55                     }
56                     Err(DecodeError::Incomplete { valid_prefix, incomplete_suffix }) => {
57                         self.data.push_str(valid_prefix);
58                         self.incomplete = Some(incomplete_suffix);
59                         Ok(())
60                     }
61                     Err(DecodeError::Invalid { valid_prefix, .. }) => {
62                         self.data.push_str(valid_prefix);
63                         Err(Error::Utf8)
64                     }
65                 }
66             } else {
67                 Ok(())
68             }
69         }
70 
into_string(self) -> Result<String>71         pub fn into_string(self) -> Result<String> {
72             if self.incomplete.is_some() {
73                 Err(Error::Utf8)
74             } else {
75                 Ok(self.data)
76             }
77         }
78     }
79 }
80 
81 use self::string_collect::StringCollector;
82 
83 /// A struct representing the incomplete message.
84 #[derive(Debug)]
85 pub struct IncompleteMessage {
86     collector: IncompleteMessageCollector,
87 }
88 
89 #[derive(Debug)]
90 enum IncompleteMessageCollector {
91     Text(StringCollector),
92     Binary(Vec<u8>),
93 }
94 
95 impl IncompleteMessage {
96     /// Create new.
new(message_type: IncompleteMessageType) -> Self97     pub fn new(message_type: IncompleteMessageType) -> Self {
98         IncompleteMessage {
99             collector: match message_type {
100                 IncompleteMessageType::Binary => IncompleteMessageCollector::Binary(Vec::new()),
101                 IncompleteMessageType::Text => {
102                     IncompleteMessageCollector::Text(StringCollector::new())
103                 }
104             },
105         }
106     }
107 
108     /// Get the current filled size of the buffer.
len(&self) -> usize109     pub fn len(&self) -> usize {
110         match self.collector {
111             IncompleteMessageCollector::Text(ref t) => t.len(),
112             IncompleteMessageCollector::Binary(ref b) => b.len(),
113         }
114     }
115 
116     /// Add more data to an existing message.
extend<T: AsRef<[u8]>>(&mut self, tail: T, size_limit: Option<usize>) -> Result<()>117     pub fn extend<T: AsRef<[u8]>>(&mut self, tail: T, size_limit: Option<usize>) -> Result<()> {
118         // Always have a max size. This ensures an error in case of concatenating two buffers
119         // of more than `usize::max_value()` bytes in total.
120         let max_size = size_limit.unwrap_or_else(usize::max_value);
121         let my_size = self.len();
122         let portion_size = tail.as_ref().len();
123         // Be careful about integer overflows here.
124         if my_size > max_size || portion_size > max_size - my_size {
125             return Err(Error::Capacity(CapacityError::MessageTooLong {
126                 size: my_size + portion_size,
127                 max_size,
128             }));
129         }
130 
131         match self.collector {
132             IncompleteMessageCollector::Binary(ref mut v) => {
133                 v.extend(tail.as_ref());
134                 Ok(())
135             }
136             IncompleteMessageCollector::Text(ref mut t) => t.extend(tail),
137         }
138     }
139 
140     /// Convert an incomplete message into a complete one.
complete(self) -> Result<Message>141     pub fn complete(self) -> Result<Message> {
142         match self.collector {
143             IncompleteMessageCollector::Binary(v) => Ok(Message::Binary(v)),
144             IncompleteMessageCollector::Text(t) => {
145                 let text = t.into_string()?;
146                 Ok(Message::Text(text))
147             }
148         }
149     }
150 }
151 
/// The type of incomplete message.
///
/// Selects whether an incomplete message collects text (validated as UTF-8) or
/// raw binary data.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IncompleteMessageType {
    /// A text message.
    Text,
    /// A binary message.
    Binary,
}
157 
158 /// An enum representing the various forms of a WebSocket message.
159 #[derive(Debug, Eq, PartialEq, Clone)]
160 pub enum Message {
161     /// A text WebSocket message
162     Text(String),
163     /// A binary WebSocket message
164     Binary(Vec<u8>),
165     /// A ping message with the specified payload
166     ///
167     /// The payload here must have a length less than 125 bytes
168     Ping(Vec<u8>),
169     /// A pong message with the specified payload
170     ///
171     /// The payload here must have a length less than 125 bytes
172     Pong(Vec<u8>),
173     /// A close message with the optional close frame.
174     Close(Option<CloseFrame<'static>>),
175     /// Raw frame. Note, that you're not going to get this value while reading the message.
176     Frame(Frame),
177 }
178 
179 impl Message {
180     /// Create a new text WebSocket message from a stringable.
text<S>(string: S) -> Message where S: Into<String>,181     pub fn text<S>(string: S) -> Message
182     where
183         S: Into<String>,
184     {
185         Message::Text(string.into())
186     }
187 
188     /// Create a new binary WebSocket message by converting to `Vec<u8>`.
binary<B>(bin: B) -> Message where B: Into<Vec<u8>>,189     pub fn binary<B>(bin: B) -> Message
190     where
191         B: Into<Vec<u8>>,
192     {
193         Message::Binary(bin.into())
194     }
195 
196     /// Indicates whether a message is a text message.
is_text(&self) -> bool197     pub fn is_text(&self) -> bool {
198         matches!(*self, Message::Text(_))
199     }
200 
201     /// Indicates whether a message is a binary message.
is_binary(&self) -> bool202     pub fn is_binary(&self) -> bool {
203         matches!(*self, Message::Binary(_))
204     }
205 
206     /// Indicates whether a message is a ping message.
is_ping(&self) -> bool207     pub fn is_ping(&self) -> bool {
208         matches!(*self, Message::Ping(_))
209     }
210 
211     /// Indicates whether a message is a pong message.
is_pong(&self) -> bool212     pub fn is_pong(&self) -> bool {
213         matches!(*self, Message::Pong(_))
214     }
215 
216     /// Indicates whether a message is a close message.
is_close(&self) -> bool217     pub fn is_close(&self) -> bool {
218         matches!(*self, Message::Close(_))
219     }
220 
221     /// Get the length of the WebSocket message.
len(&self) -> usize222     pub fn len(&self) -> usize {
223         match *self {
224             Message::Text(ref string) => string.len(),
225             Message::Binary(ref data) | Message::Ping(ref data) | Message::Pong(ref data) => {
226                 data.len()
227             }
228             Message::Close(ref data) => data.as_ref().map(|d| d.reason.len()).unwrap_or(0),
229             Message::Frame(ref frame) => frame.len(),
230         }
231     }
232 
233     /// Returns true if the WebSocket message has no content.
234     /// For example, if the other side of the connection sent an empty string.
is_empty(&self) -> bool235     pub fn is_empty(&self) -> bool {
236         self.len() == 0
237     }
238 
239     /// Consume the WebSocket and return it as binary data.
into_data(self) -> Vec<u8>240     pub fn into_data(self) -> Vec<u8> {
241         match self {
242             Message::Text(string) => string.into_bytes(),
243             Message::Binary(data) | Message::Ping(data) | Message::Pong(data) => data,
244             Message::Close(None) => Vec::new(),
245             Message::Close(Some(frame)) => frame.reason.into_owned().into_bytes(),
246             Message::Frame(frame) => frame.into_data(),
247         }
248     }
249 
250     /// Attempt to consume the WebSocket message and convert it to a String.
into_text(self) -> Result<String>251     pub fn into_text(self) -> Result<String> {
252         match self {
253             Message::Text(string) => Ok(string),
254             Message::Binary(data) | Message::Ping(data) | Message::Pong(data) => {
255                 Ok(String::from_utf8(data)?)
256             }
257             Message::Close(None) => Ok(String::new()),
258             Message::Close(Some(frame)) => Ok(frame.reason.into_owned()),
259             Message::Frame(frame) => Ok(frame.into_string()?),
260         }
261     }
262 
263     /// Attempt to get a &str from the WebSocket message,
264     /// this will try to convert binary data to utf8.
to_text(&self) -> Result<&str>265     pub fn to_text(&self) -> Result<&str> {
266         match *self {
267             Message::Text(ref string) => Ok(string),
268             Message::Binary(ref data) | Message::Ping(ref data) | Message::Pong(ref data) => {
269                 Ok(str::from_utf8(data)?)
270             }
271             Message::Close(None) => Ok(""),
272             Message::Close(Some(ref frame)) => Ok(&frame.reason),
273             Message::Frame(ref frame) => Ok(frame.to_text()?),
274         }
275     }
276 }
277 
278 impl From<String> for Message {
from(string: String) -> Self279     fn from(string: String) -> Self {
280         Message::text(string)
281     }
282 }
283 
284 impl<'s> From<&'s str> for Message {
from(string: &'s str) -> Self285     fn from(string: &'s str) -> Self {
286         Message::text(string)
287     }
288 }
289 
290 impl<'b> From<&'b [u8]> for Message {
from(data: &'b [u8]) -> Self291     fn from(data: &'b [u8]) -> Self {
292         Message::binary(data)
293     }
294 }
295 
296 impl From<Vec<u8>> for Message {
from(data: Vec<u8>) -> Self297     fn from(data: Vec<u8>) -> Self {
298         Message::binary(data)
299     }
300 }
301 
302 impl From<Message> for Vec<u8> {
from(message: Message) -> Self303     fn from(message: Message) -> Self {
304         message.into_data()
305     }
306 }
307 
308 impl TryFrom<Message> for String {
309     type Error = Error;
310 
try_from(value: Message) -> StdResult<Self, Self::Error>311     fn try_from(value: Message) -> StdResult<Self, Self::Error> {
312         value.into_text()
313     }
314 }
315 
316 impl fmt::Display for Message {
fmt(&self, f: &mut fmt::Formatter) -> StdResult<(), fmt::Error>317     fn fmt(&self, f: &mut fmt::Formatter) -> StdResult<(), fmt::Error> {
318         if let Ok(string) = self.to_text() {
319             write!(f, "{}", string)
320         } else {
321             write!(f, "Binary Data<length={}>", self.len())
322         }
323     }
324 }
325 
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn display() {
        let text_msg = Message::text(String::from("test"));
        assert_eq!(text_msg.to_string(), String::from("test"));

        let binary_msg = Message::binary(vec![0, 1, 3, 4, 241]);
        assert_eq!(binary_msg.to_string(), String::from("Binary Data<length=5>"));
    }

    #[test]
    fn binary_convert() {
        let bytes: &[u8] = &[6, 7, 8, 9, 10, 241];
        let msg = Message::from(bytes);
        assert!(msg.is_binary());
        assert!(msg.into_text().is_err());
    }

    #[test]
    fn binary_convert_vec() {
        let msg: Message = vec![6u8, 7, 8, 9, 10, 241].into();
        assert!(msg.is_binary());
        assert!(msg.into_text().is_err());
    }

    #[test]
    fn binary_convert_into_vec() {
        let original = vec![6u8, 7, 8, 9, 10, 241];
        let roundtrip = Vec::<u8>::from(Message::from(original.clone()));
        assert_eq!(original, roundtrip);
    }

    #[test]
    fn text_convert() {
        let msg = Message::from("kiwotsukete");
        assert!(msg.is_text());
    }
}
371