1 use std::{io::Read, marker::PhantomData};
2 
3 use log::trace;
4 use serde::de::{self, Unexpected};
5 use serde::forward_to_deserialize_any;
6 use xml::name::OwnedName;
7 use xml::reader::{EventReader, ParserConfig, XmlEvent};
8 
9 use self::buffer::{BufferedXmlReader, ChildXmlBuffer, RootXmlBuffer};
10 use self::map::MapAccess;
11 use self::seq::SeqAccess;
12 use self::var::EnumAccess;
13 use crate::error::{Error, Result};
14 use crate::{debug_expect, expect};
15 
16 mod buffer;
17 mod map;
18 mod seq;
19 mod var;
20 
21 /// A convenience method for deserialize some object from a string.
22 ///
23 /// ```rust
24 /// # use serde::{Deserialize, Serialize};
25 /// # use serde_xml_rs::from_str;
26 /// #[derive(Debug, Deserialize, PartialEq)]
27 /// struct Item {
28 ///     name: String,
29 ///     source: String,
30 /// }
31 /// # fn main() {
32 /// let s = r##"<item name="hello" source="world.rs" />"##;
33 /// let item: Item = from_str(s).unwrap();
34 /// assert_eq!(item, Item { name: "hello".to_string(),source: "world.rs".to_string()});
35 /// # }
36 /// ```
from_str<'de, T: de::Deserialize<'de>>(s: &str) -> Result<T>37 pub fn from_str<'de, T: de::Deserialize<'de>>(s: &str) -> Result<T> {
38     from_reader(s.as_bytes())
39 }
40 
41 /// A convenience method for deserialize some object from a reader.
42 ///
43 /// ```rust
44 /// # use serde::Deserialize;
45 /// # use serde_xml_rs::from_reader;
46 /// #[derive(Debug, Deserialize, PartialEq)]
47 /// struct Item {
48 ///     name: String,
49 ///     source: String,
50 /// }
51 /// # fn main() {
52 /// let s = r##"<item name="hello" source="world.rs" />"##;
53 /// let item: Item = from_reader(s.as_bytes()).unwrap();
54 /// assert_eq!(item, Item { name: "hello".to_string(),source: "world.rs".to_string()});
55 /// # }
56 /// ```
from_reader<'de, R: Read, T: de::Deserialize<'de>>(reader: R) -> Result<T>57 pub fn from_reader<'de, R: Read, T: de::Deserialize<'de>>(reader: R) -> Result<T> {
58     T::deserialize(&mut Deserializer::new_from_reader(reader))
59 }
60 
61 type RootDeserializer<R> = Deserializer<R, RootXmlBuffer<R>>;
62 type ChildDeserializer<'parent, R> = Deserializer<R, ChildXmlBuffer<'parent, R>>;
63 
64 pub struct Deserializer<
65     R: Read, // Kept as type param to avoid type signature breaking-change
66     B: BufferedXmlReader<R> = RootXmlBuffer<R>,
67 > {
68     /// XML document nested element depth
69     depth: usize,
70     buffered_reader: B,
71     is_map_value: bool,
72     non_contiguous_seq_elements: bool,
73     marker: PhantomData<R>,
74 }
75 
76 impl<'de, R: Read> RootDeserializer<R> {
new(reader: EventReader<R>) -> Self77     pub fn new(reader: EventReader<R>) -> Self {
78         let buffered_reader = RootXmlBuffer::new(reader);
79 
80         Deserializer {
81             buffered_reader,
82             depth: 0,
83             is_map_value: false,
84             non_contiguous_seq_elements: false,
85             marker: PhantomData,
86         }
87     }
88 
new_from_reader(reader: R) -> Self89     pub fn new_from_reader(reader: R) -> Self {
90         let config = ParserConfig::new()
91             .trim_whitespace(true)
92             .whitespace_to_characters(true)
93             .cdata_to_characters(true)
94             .ignore_comments(true)
95             .coalesce_characters(true);
96 
97         Self::new(EventReader::new_with_config(reader, config))
98     }
99 
100     /// Configures whether the deserializer should search all sibling elements when building a
101     /// sequence. Not required if all XML elements for sequences are adjacent. Disabled by
102     /// default. Enabling this option may incur additional memory usage.
103     ///
104     /// ```rust
105     /// # use serde::Deserialize;
106     /// # use serde_xml_rs::from_reader;
107     /// #[derive(Debug, Deserialize, PartialEq)]
108     /// struct Foo {
109     ///     bar: Vec<usize>,
110     ///     baz: String,
111     /// }
112     /// # fn main() {
113     /// let s = r##"
114     ///     <foo>
115     ///         <bar>1</bar>
116     ///         <bar>2</bar>
117     ///         <baz>Hello, world</baz>
118     ///         <bar>3</bar>
119     ///         <bar>4</bar>
120     ///     </foo>
121     /// "##;
122     /// let mut de = serde_xml_rs::Deserializer::new_from_reader(s.as_bytes())
123     ///     .non_contiguous_seq_elements(true);
124     /// let foo = Foo::deserialize(&mut de).unwrap();
125     /// assert_eq!(foo, Foo { bar: vec![1, 2, 3, 4], baz: "Hello, world".to_string()});
126     /// # }
127     /// ```
non_contiguous_seq_elements(mut self, set: bool) -> Self128     pub fn non_contiguous_seq_elements(mut self, set: bool) -> Self {
129         self.non_contiguous_seq_elements = set;
130         self
131     }
132 }
133 
134 impl<'de, R: Read, B: BufferedXmlReader<R>> Deserializer<R, B> {
child<'a>(&'a mut self) -> Deserializer<R, ChildXmlBuffer<'a, R>>135     fn child<'a>(&'a mut self) -> Deserializer<R, ChildXmlBuffer<'a, R>> {
136         let Deserializer {
137             buffered_reader,
138             depth,
139             is_map_value,
140             non_contiguous_seq_elements,
141             ..
142         } = self;
143 
144         Deserializer {
145             buffered_reader: buffered_reader.child_buffer(),
146             depth: *depth,
147             is_map_value: *is_map_value,
148             non_contiguous_seq_elements: *non_contiguous_seq_elements,
149             marker: PhantomData,
150         }
151     }
152 
153     /// Gets the next XML event without advancing the cursor.
peek(&mut self) -> Result<&XmlEvent>154     fn peek(&mut self) -> Result<&XmlEvent> {
155         let peeked = self.buffered_reader.peek()?;
156 
157         trace!("Peeked {:?}", peeked);
158         Ok(peeked)
159     }
160 
161     /// Gets the XML event at the cursor and advances the cursor.
next(&mut self) -> Result<XmlEvent>162     fn next(&mut self) -> Result<XmlEvent> {
163         let next = self.buffered_reader.next()?;
164 
165         match next {
166             XmlEvent::StartElement { .. } => {
167                 self.depth += 1;
168             }
169             XmlEvent::EndElement { .. } => {
170                 self.depth -= 1;
171             }
172             _ => {}
173         }
174         trace!("Fetched {:?}", next);
175         Ok(next)
176     }
177 
set_map_value(&mut self)178     fn set_map_value(&mut self) {
179         self.is_map_value = true;
180     }
181 
unset_map_value(&mut self) -> bool182     pub fn unset_map_value(&mut self) -> bool {
183         ::std::mem::replace(&mut self.is_map_value, false)
184     }
185 
186     /// If `self.is_map_value`: Performs the read operations specified by `f` on the inner content of an XML element.
187     /// `f` is expected to consume the entire inner contents of the element. The cursor will be moved to the end of the
188     /// element.
189     /// If `!self.is_map_value`: `f` will be performed without additional checks/advances for an outer XML element.
read_inner_value<V: de::Visitor<'de>, T, F: FnOnce(&mut Self) -> Result<T>>( &mut self, f: F, ) -> Result<T>190     fn read_inner_value<V: de::Visitor<'de>, T, F: FnOnce(&mut Self) -> Result<T>>(
191         &mut self,
192         f: F,
193     ) -> Result<T> {
194         if self.unset_map_value() {
195             debug_expect!(self.next(), Ok(XmlEvent::StartElement { name, .. }) => {
196                 let result = f(self)?;
197                 self.expect_end_element(name)?;
198                 Ok(result)
199             })
200         } else {
201             f(self)
202         }
203     }
204 
expect_end_element(&mut self, start_name: OwnedName) -> Result<()>205     fn expect_end_element(&mut self, start_name: OwnedName) -> Result<()> {
206         expect!(self.next()?, XmlEvent::EndElement { name, .. } => {
207             if name == start_name {
208                 Ok(())
209             } else {
210                 Err(Error::Custom { field: format!(
211                     "End tag </{}> didn't match the start tag <{}>",
212                     name.local_name,
213                     start_name.local_name
214                 ) })
215             }
216         })
217     }
218 
prepare_parse_type<V: de::Visitor<'de>>(&mut self) -> Result<String>219     fn prepare_parse_type<V: de::Visitor<'de>>(&mut self) -> Result<String> {
220         if let XmlEvent::StartElement { .. } = *self.peek()? {
221             self.set_map_value()
222         }
223         self.read_inner_value::<V, String, _>(|this| {
224             if let XmlEvent::EndElement { .. } = *this.peek()? {
225                 return Err(Error::UnexpectedToken {
226                     token: "EndElement".into(),
227                     found: "Characters".into(),
228                 });
229             }
230 
231             expect!(this.next()?, XmlEvent::Characters(s) => {
232                 return Ok(s)
233             })
234         })
235     }
236 }
237 
// Expands to a `Deserializer` method named `$deserialize` that reads the
// current text value and hands it, parsed via `str::parse`, to
// `visitor.$visit`. The parse target type is inferred from `$visit`'s argument.
macro_rules! deserialize_type {
    ($deserialize:ident => $visit:ident) => {
        fn $deserialize<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
            let text = self.prepare_parse_type::<V>()?;
            let parsed = text.parse()?;
            visitor.$visit(parsed)
        }
    };
}
246 
247 impl<'de, 'a, R: Read, B: BufferedXmlReader<R>> de::Deserializer<'de>
248     for &'a mut Deserializer<R, B>
249 {
250     type Error = Error;
251 
252     forward_to_deserialize_any! {
253         identifier
254     }
255 
deserialize_struct<V: de::Visitor<'de>>( self, _name: &'static str, fields: &'static [&'static str], visitor: V, ) -> Result<V::Value>256     fn deserialize_struct<V: de::Visitor<'de>>(
257         self,
258         _name: &'static str,
259         fields: &'static [&'static str],
260         visitor: V,
261     ) -> Result<V::Value> {
262         self.unset_map_value();
263         expect!(self.next()?, XmlEvent::StartElement { name, attributes, .. } => {
264             let map_value = visitor.visit_map(MapAccess::new(
265                 self,
266                 attributes,
267                 fields.contains(&"$value")
268             ))?;
269             self.expect_end_element(name)?;
270             Ok(map_value)
271         })
272     }
273 
274     deserialize_type!(deserialize_i8 => visit_i8);
275     deserialize_type!(deserialize_i16 => visit_i16);
276     deserialize_type!(deserialize_i32 => visit_i32);
277     deserialize_type!(deserialize_i64 => visit_i64);
278     deserialize_type!(deserialize_u8 => visit_u8);
279     deserialize_type!(deserialize_u16 => visit_u16);
280     deserialize_type!(deserialize_u32 => visit_u32);
281     deserialize_type!(deserialize_u64 => visit_u64);
282     deserialize_type!(deserialize_f32 => visit_f32);
283     deserialize_type!(deserialize_f64 => visit_f64);
284 
deserialize_bool<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value>285     fn deserialize_bool<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
286         if let XmlEvent::StartElement { .. } = *self.peek()? {
287             self.set_map_value()
288         }
289         self.read_inner_value::<V, V::Value, _>(|this| {
290             if let XmlEvent::EndElement { .. } = *this.peek()? {
291                 return visitor.visit_bool(false);
292             }
293             expect!(this.next()?, XmlEvent::Characters(s) => {
294                 match s.as_str() {
295                     "true" | "1" => visitor.visit_bool(true),
296                     "false" | "0" => visitor.visit_bool(false),
297                     _ => Err(de::Error::invalid_value(Unexpected::Str(&s), &"a boolean")),
298                 }
299 
300             })
301         })
302     }
303 
deserialize_char<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value>304     fn deserialize_char<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
305         self.deserialize_string(visitor)
306     }
307 
deserialize_str<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value>308     fn deserialize_str<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
309         self.deserialize_string(visitor)
310     }
311 
deserialize_bytes<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value>312     fn deserialize_bytes<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
313         self.deserialize_string(visitor)
314     }
315 
deserialize_byte_buf<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value>316     fn deserialize_byte_buf<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
317         self.deserialize_string(visitor)
318     }
319 
deserialize_unit<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value>320     fn deserialize_unit<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
321         if let XmlEvent::StartElement { .. } = *self.peek()? {
322             self.set_map_value()
323         }
324         self.read_inner_value::<V, V::Value, _>(
325             |this| expect!(this.peek()?, &XmlEvent::EndElement { .. } => visitor.visit_unit()),
326         )
327     }
328 
deserialize_unit_struct<V: de::Visitor<'de>>( self, _name: &'static str, visitor: V, ) -> Result<V::Value>329     fn deserialize_unit_struct<V: de::Visitor<'de>>(
330         self,
331         _name: &'static str,
332         visitor: V,
333     ) -> Result<V::Value> {
334         self.deserialize_unit(visitor)
335     }
336 
deserialize_newtype_struct<V: de::Visitor<'de>>( self, _name: &'static str, visitor: V, ) -> Result<V::Value>337     fn deserialize_newtype_struct<V: de::Visitor<'de>>(
338         self,
339         _name: &'static str,
340         visitor: V,
341     ) -> Result<V::Value> {
342         visitor.visit_newtype_struct(self)
343     }
344 
deserialize_tuple_struct<V: de::Visitor<'de>>( self, _name: &'static str, len: usize, visitor: V, ) -> Result<V::Value>345     fn deserialize_tuple_struct<V: de::Visitor<'de>>(
346         self,
347         _name: &'static str,
348         len: usize,
349         visitor: V,
350     ) -> Result<V::Value> {
351         self.deserialize_tuple(len, visitor)
352     }
353 
deserialize_tuple<V: de::Visitor<'de>>(self, len: usize, visitor: V) -> Result<V::Value>354     fn deserialize_tuple<V: de::Visitor<'de>>(self, len: usize, visitor: V) -> Result<V::Value> {
355         let child_deserializer = self.child();
356 
357         visitor.visit_seq(SeqAccess::new(child_deserializer, Some(len)))
358     }
359 
deserialize_enum<V: de::Visitor<'de>>( self, _name: &'static str, _variants: &'static [&'static str], visitor: V, ) -> Result<V::Value>360     fn deserialize_enum<V: de::Visitor<'de>>(
361         self,
362         _name: &'static str,
363         _variants: &'static [&'static str],
364         visitor: V,
365     ) -> Result<V::Value> {
366         self.read_inner_value::<V, V::Value, _>(|this| visitor.visit_enum(EnumAccess::new(this)))
367     }
368 
deserialize_string<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value>369     fn deserialize_string<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
370         if let XmlEvent::StartElement { .. } = *self.peek()? {
371             self.set_map_value()
372         }
373         self.read_inner_value::<V, V::Value, _>(|this| {
374             if let XmlEvent::EndElement { .. } = *this.peek()? {
375                 return visitor.visit_str("");
376             }
377             expect!(this.next()?, XmlEvent::Characters(s) => {
378                 visitor.visit_string(s)
379             })
380         })
381     }
382 
deserialize_seq<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value>383     fn deserialize_seq<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
384         let child_deserializer = self.child();
385 
386         visitor.visit_seq(SeqAccess::new(child_deserializer, None))
387     }
388 
deserialize_map<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value>389     fn deserialize_map<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
390         self.unset_map_value();
391         expect!(self.next()?, XmlEvent::StartElement { name, attributes, .. } => {
392             let map_value = visitor.visit_map(MapAccess::new(self, attributes, false))?;
393             self.expect_end_element(name)?;
394             Ok(map_value)
395         })
396     }
397 
deserialize_option<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value>398     fn deserialize_option<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
399         match *self.peek()? {
400             XmlEvent::EndElement { .. } => visitor.visit_none(),
401             _ => visitor.visit_some(self),
402         }
403     }
404 
deserialize_ignored_any<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value>405     fn deserialize_ignored_any<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
406         self.unset_map_value();
407         let depth = self.depth;
408         loop {
409             self.next()?;
410             if self.depth == depth {
411                 break;
412             }
413         }
414         visitor.visit_unit()
415     }
416 
deserialize_any<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value>417     fn deserialize_any<V: de::Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
418         match *self.peek()? {
419             XmlEvent::StartElement { .. } => self.deserialize_map(visitor),
420             XmlEvent::EndElement { .. } => self.deserialize_unit(visitor),
421             _ => self.deserialize_string(visitor),
422         }
423     }
424 }
425