1 #![cfg(feature = "serde")]
2 
3 use serde::de::{
4     Deserialize, DeserializeSeed, Deserializer, EnumAccess, Error, Unexpected, VariantAccess,
5     Visitor,
6 };
7 use serde::ser::{Serialize, Serializer};
8 
9 use crate::{Level, LevelFilter, LOG_LEVEL_NAMES};
10 
11 use std::fmt;
12 use std::str::{self, FromStr};
13 
14 // The Deserialize impls are handwritten to be case insensitive using FromStr.
15 
16 impl Serialize for Level {
serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> where S: Serializer,17     fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
18     where
19         S: Serializer,
20     {
21         match *self {
22             Level::Error => serializer.serialize_unit_variant("Level", 0, "ERROR"),
23             Level::Warn => serializer.serialize_unit_variant("Level", 1, "WARN"),
24             Level::Info => serializer.serialize_unit_variant("Level", 2, "INFO"),
25             Level::Debug => serializer.serialize_unit_variant("Level", 3, "DEBUG"),
26             Level::Trace => serializer.serialize_unit_variant("Level", 4, "TRACE"),
27         }
28     }
29 }
30 
31 impl<'de> Deserialize<'de> for Level {
deserialize<D>(deserializer: D) -> Result<Self, D::Error> where D: Deserializer<'de>,32     fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
33     where
34         D: Deserializer<'de>,
35     {
36         struct LevelIdentifier;
37 
38         impl<'de> Visitor<'de> for LevelIdentifier {
39             type Value = Level;
40 
41             fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
42                 formatter.write_str("log level")
43             }
44 
45             fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
46             where
47                 E: Error,
48             {
49                 let variant = LOG_LEVEL_NAMES[1..]
50                     .get(v as usize)
51                     .ok_or_else(|| Error::invalid_value(Unexpected::Unsigned(v), &self))?;
52 
53                 self.visit_str(variant)
54             }
55 
56             fn visit_str<E>(self, s: &str) -> Result<Self::Value, E>
57             where
58                 E: Error,
59             {
60                 // Case insensitive.
61                 FromStr::from_str(s).map_err(|_| Error::unknown_variant(s, &LOG_LEVEL_NAMES[1..]))
62             }
63 
64             fn visit_bytes<E>(self, value: &[u8]) -> Result<Self::Value, E>
65             where
66                 E: Error,
67             {
68                 let variant = str::from_utf8(value)
69                     .map_err(|_| Error::invalid_value(Unexpected::Bytes(value), &self))?;
70 
71                 self.visit_str(variant)
72             }
73         }
74 
75         impl<'de> DeserializeSeed<'de> for LevelIdentifier {
76             type Value = Level;
77 
78             fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
79             where
80                 D: Deserializer<'de>,
81             {
82                 deserializer.deserialize_identifier(LevelIdentifier)
83             }
84         }
85 
86         struct LevelEnum;
87 
88         impl<'de> Visitor<'de> for LevelEnum {
89             type Value = Level;
90 
91             fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
92                 formatter.write_str("log level")
93             }
94 
95             fn visit_enum<A>(self, value: A) -> Result<Self::Value, A::Error>
96             where
97                 A: EnumAccess<'de>,
98             {
99                 let (level, variant) = value.variant_seed(LevelIdentifier)?;
100                 // Every variant is a unit variant.
101                 variant.unit_variant()?;
102                 Ok(level)
103             }
104         }
105 
106         deserializer.deserialize_enum("Level", &LOG_LEVEL_NAMES[1..], LevelEnum)
107     }
108 }
109 
110 impl Serialize for LevelFilter {
serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> where S: Serializer,111     fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
112     where
113         S: Serializer,
114     {
115         match *self {
116             LevelFilter::Off => serializer.serialize_unit_variant("LevelFilter", 0, "OFF"),
117             LevelFilter::Error => serializer.serialize_unit_variant("LevelFilter", 1, "ERROR"),
118             LevelFilter::Warn => serializer.serialize_unit_variant("LevelFilter", 2, "WARN"),
119             LevelFilter::Info => serializer.serialize_unit_variant("LevelFilter", 3, "INFO"),
120             LevelFilter::Debug => serializer.serialize_unit_variant("LevelFilter", 4, "DEBUG"),
121             LevelFilter::Trace => serializer.serialize_unit_variant("LevelFilter", 5, "TRACE"),
122         }
123     }
124 }
125 
126 impl<'de> Deserialize<'de> for LevelFilter {
deserialize<D>(deserializer: D) -> Result<Self, D::Error> where D: Deserializer<'de>,127     fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
128     where
129         D: Deserializer<'de>,
130     {
131         struct LevelFilterIdentifier;
132 
133         impl<'de> Visitor<'de> for LevelFilterIdentifier {
134             type Value = LevelFilter;
135 
136             fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
137                 formatter.write_str("log level filter")
138             }
139 
140             fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
141             where
142                 E: Error,
143             {
144                 let variant = LOG_LEVEL_NAMES
145                     .get(v as usize)
146                     .ok_or_else(|| Error::invalid_value(Unexpected::Unsigned(v), &self))?;
147 
148                 self.visit_str(variant)
149             }
150 
151             fn visit_str<E>(self, s: &str) -> Result<Self::Value, E>
152             where
153                 E: Error,
154             {
155                 // Case insensitive.
156                 FromStr::from_str(s).map_err(|_| Error::unknown_variant(s, &LOG_LEVEL_NAMES))
157             }
158 
159             fn visit_bytes<E>(self, value: &[u8]) -> Result<Self::Value, E>
160             where
161                 E: Error,
162             {
163                 let variant = str::from_utf8(value)
164                     .map_err(|_| Error::invalid_value(Unexpected::Bytes(value), &self))?;
165 
166                 self.visit_str(variant)
167             }
168         }
169 
170         impl<'de> DeserializeSeed<'de> for LevelFilterIdentifier {
171             type Value = LevelFilter;
172 
173             fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
174             where
175                 D: Deserializer<'de>,
176             {
177                 deserializer.deserialize_identifier(LevelFilterIdentifier)
178             }
179         }
180 
181         struct LevelFilterEnum;
182 
183         impl<'de> Visitor<'de> for LevelFilterEnum {
184             type Value = LevelFilter;
185 
186             fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
187                 formatter.write_str("log level filter")
188             }
189 
190             fn visit_enum<A>(self, value: A) -> Result<Self::Value, A::Error>
191             where
192                 A: EnumAccess<'de>,
193             {
194                 let (level_filter, variant) = value.variant_seed(LevelFilterIdentifier)?;
195                 // Every variant is a unit variant.
196                 variant.unit_variant()?;
197                 Ok(level_filter)
198             }
199         }
200 
201         deserializer.deserialize_enum("LevelFilter", &LOG_LEVEL_NAMES, LevelFilterEnum)
202     }
203 }
204 
#[cfg(test)]
mod tests {
    use crate::{Level, LevelFilter};
    use serde_test::{assert_de_tokens, assert_de_tokens_error, assert_tokens, Token};

    /// A `Level` encoded as a unit variant identified by name.
    fn level_token(variant: &'static str) -> Token {
        Token::UnitVariant {
            name: "Level",
            variant,
        }
    }

    /// A `Level` encoded as an enum whose identifier is raw bytes.
    fn level_bytes_tokens(variant: &'static [u8]) -> [Token; 3] {
        [
            Token::Enum { name: "Level" },
            Token::Bytes(variant),
            Token::Unit,
        ]
    }

    /// A `Level` encoded as an enum whose identifier is a variant index.
    fn level_variant_tokens(variant: u32) -> [Token; 3] {
        [
            Token::Enum { name: "Level" },
            Token::U32(variant),
            Token::Unit,
        ]
    }

    /// A `LevelFilter` encoded as a unit variant identified by name.
    fn level_filter_token(variant: &'static str) -> Token {
        Token::UnitVariant {
            name: "LevelFilter",
            variant,
        }
    }

    /// A `LevelFilter` encoded as an enum whose identifier is raw bytes.
    fn level_filter_bytes_tokens(variant: &'static [u8]) -> [Token; 3] {
        [
            Token::Enum {
                name: "LevelFilter",
            },
            Token::Bytes(variant),
            Token::Unit,
        ]
    }

    /// A `LevelFilter` encoded as an enum whose identifier is a variant index.
    fn level_filter_variant_tokens(variant: u32) -> [Token; 3] {
        [
            Token::Enum {
                name: "LevelFilter",
            },
            Token::U32(variant),
            Token::Unit,
        ]
    }

    #[test]
    fn test_level_ser_de() {
        let cases = &[
            (Level::Error, [level_token("ERROR")]),
            (Level::Warn, [level_token("WARN")]),
            (Level::Info, [level_token("INFO")]),
            (Level::Debug, [level_token("DEBUG")]),
            (Level::Trace, [level_token("TRACE")]),
        ];

        for (s, expected) in cases {
            assert_tokens(s, expected);
        }
    }

    #[test]
    fn test_level_case_insensitive() {
        let cases = &[
            (Level::Error, [level_token("error")]),
            (Level::Warn, [level_token("warn")]),
            (Level::Info, [level_token("info")]),
            (Level::Debug, [level_token("debug")]),
            (Level::Trace, [level_token("trace")]),
        ];

        for (s, expected) in cases {
            assert_de_tokens(s, expected);
        }
    }

    #[test]
    fn test_level_de_bytes() {
        let cases = &[
            (Level::Error, level_bytes_tokens(b"ERROR")),
            (Level::Warn, level_bytes_tokens(b"WARN")),
            (Level::Info, level_bytes_tokens(b"INFO")),
            (Level::Debug, level_bytes_tokens(b"DEBUG")),
            (Level::Trace, level_bytes_tokens(b"TRACE")),
        ];

        for (value, tokens) in cases {
            assert_de_tokens(value, tokens);
        }
    }

    #[test]
    fn test_level_de_variant_index() {
        let cases = &[
            (Level::Error, level_variant_tokens(0)),
            (Level::Warn, level_variant_tokens(1)),
            (Level::Info, level_variant_tokens(2)),
            (Level::Debug, level_variant_tokens(3)),
            (Level::Trace, level_variant_tokens(4)),
        ];

        for (value, tokens) in cases {
            assert_de_tokens(value, tokens);
        }
    }

    /// Indices past the last variant, including one that does not fit in a
    /// 32-bit `usize`, must be rejected rather than wrapped.
    #[test]
    fn test_level_de_variant_index_out_of_range() {
        assert_de_tokens_error::<Level>(
            &level_variant_tokens(5),
            "invalid value: integer `5`, expected log level",
        );
        assert_de_tokens_error::<Level>(
            &[Token::Enum { name: "Level" }, Token::U64(1 << 32), Token::Unit],
            "invalid value: integer `4294967296`, expected log level",
        );
    }

    #[test]
    fn test_level_de_invalid_utf8_bytes() {
        assert_de_tokens_error::<Level>(
            &level_bytes_tokens(b"\xFF"),
            "invalid value: byte array, expected log level",
        );
    }

    #[test]
    fn test_level_de_error() {
        let msg = "unknown variant `errorx`, expected one of \
                   `ERROR`, `WARN`, `INFO`, `DEBUG`, `TRACE`";
        assert_de_tokens_error::<Level>(&[level_token("errorx")], msg);
    }

    #[test]
    fn test_level_filter_ser_de() {
        let cases = &[
            (LevelFilter::Off, [level_filter_token("OFF")]),
            (LevelFilter::Error, [level_filter_token("ERROR")]),
            (LevelFilter::Warn, [level_filter_token("WARN")]),
            (LevelFilter::Info, [level_filter_token("INFO")]),
            (LevelFilter::Debug, [level_filter_token("DEBUG")]),
            (LevelFilter::Trace, [level_filter_token("TRACE")]),
        ];

        for (s, expected) in cases {
            assert_tokens(s, expected);
        }
    }

    #[test]
    fn test_level_filter_case_insensitive() {
        let cases = &[
            (LevelFilter::Off, [level_filter_token("off")]),
            (LevelFilter::Error, [level_filter_token("error")]),
            (LevelFilter::Warn, [level_filter_token("warn")]),
            (LevelFilter::Info, [level_filter_token("info")]),
            (LevelFilter::Debug, [level_filter_token("debug")]),
            (LevelFilter::Trace, [level_filter_token("trace")]),
        ];

        for (s, expected) in cases {
            assert_de_tokens(s, expected);
        }
    }

    #[test]
    fn test_level_filter_de_bytes() {
        let cases = &[
            (LevelFilter::Off, level_filter_bytes_tokens(b"OFF")),
            (LevelFilter::Error, level_filter_bytes_tokens(b"ERROR")),
            (LevelFilter::Warn, level_filter_bytes_tokens(b"WARN")),
            (LevelFilter::Info, level_filter_bytes_tokens(b"INFO")),
            (LevelFilter::Debug, level_filter_bytes_tokens(b"DEBUG")),
            (LevelFilter::Trace, level_filter_bytes_tokens(b"TRACE")),
        ];

        for (value, tokens) in cases {
            assert_de_tokens(value, tokens);
        }
    }

    #[test]
    fn test_level_filter_de_variant_index() {
        let cases = &[
            (LevelFilter::Off, level_filter_variant_tokens(0)),
            (LevelFilter::Error, level_filter_variant_tokens(1)),
            (LevelFilter::Warn, level_filter_variant_tokens(2)),
            (LevelFilter::Info, level_filter_variant_tokens(3)),
            (LevelFilter::Debug, level_filter_variant_tokens(4)),
            (LevelFilter::Trace, level_filter_variant_tokens(5)),
        ];

        for (value, tokens) in cases {
            assert_de_tokens(value, tokens);
        }
    }

    /// Indices past the last variant, including one that does not fit in a
    /// 32-bit `usize`, must be rejected rather than wrapped.
    #[test]
    fn test_level_filter_de_variant_index_out_of_range() {
        assert_de_tokens_error::<LevelFilter>(
            &level_filter_variant_tokens(6),
            "invalid value: integer `6`, expected log level filter",
        );
        assert_de_tokens_error::<LevelFilter>(
            &[
                Token::Enum {
                    name: "LevelFilter",
                },
                Token::U64(1 << 32),
                Token::Unit,
            ],
            "invalid value: integer `4294967296`, expected log level filter",
        );
    }

    #[test]
    fn test_level_filter_de_invalid_utf8_bytes() {
        assert_de_tokens_error::<LevelFilter>(
            &level_filter_bytes_tokens(b"\xFF"),
            "invalid value: byte array, expected log level filter",
        );
    }

    #[test]
    fn test_level_filter_de_error() {
        let msg = "unknown variant `errorx`, expected one of \
                   `OFF`, `ERROR`, `WARN`, `INFO`, `DEBUG`, `TRACE`";
        assert_de_tokens_error::<LevelFilter>(&[level_filter_token("errorx")], msg);
    }
}
398